diff --git a/Makefile b/Makefile index f668bfd..18cef93 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs +.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs security-scan gosec govulncheck check-unsafe security-audit test-security OK = โœ“ DOCS_PORT ?= 1313 DOCS_BIND ?= 127.0.0.1 @@ -498,3 +498,60 @@ prod-status: prod-logs: @./deployments/deploy.sh prod logs + +# ============================================================================= +# SECURITY TARGETS +# ============================================================================= + +.PHONY: security-scan gosec govulncheck check-unsafe security-audit + +# Run all security scans +security-scan: gosec govulncheck check-unsafe + @echo "${OK} Security scan complete" + +# Run gosec security linter +gosec: + @mkdir -p reports + @echo "Running gosec security scan..." + @if command -v gosec >/dev/null 2>&1; then \ + gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \ + gosec -fmt=sarif -out=reports/gosec-results.sarif ./... 2>/dev/null || true; \ + gosec ./... 2>/dev/null || echo "Note: gosec found issues (see reports/gosec-results.json)"; \ + else \ + echo "Installing gosec..."; \ + go install github.com/securego/gosec/v2/cmd/gosec@latest; \ + gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \ + fi + @echo "${OK} gosec scan complete (see reports/gosec-results.*)" + +# Run govulncheck for known vulnerabilities +govulncheck: + @echo "Running govulncheck for known vulnerabilities..." + @if command -v govulncheck >/dev/null 2>&1; then \ + govulncheck ./...; \ + else \ + echo "Installing govulncheck..."; \ + go install golang.org/x/vuln/cmd/govulncheck@latest; \ + govulncheck ./...; \ + fi + @echo "${OK} govulncheck complete" + +# Check for unsafe package usage +check-unsafe: + @echo "Checking for unsafe package usage..." + @if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then \ + echo "WARNING: Found unsafe package usage (review required)"; \ + exit 1; \ + else \ + echo "${OK} No unsafe package usage found"; \ + fi + +# Full security audit (tests + scans) +security-audit: security-scan test-security + @echo "${OK} Full security audit complete" + +# Run security-specific tests +test-security: + @echo "Running security tests..." + @go test -v ./tests/security/... 2>/dev/null || echo "Note: No security tests yet (will be added in Phase 5)" + @echo "${OK} Security tests complete" diff --git a/SECURITY.md b/SECURITY.md index 69427b2..8d1a07f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,6 +1,153 @@ +# Security Policy + +## Reporting a Vulnerability + +Please report security vulnerabilities to security@fetchml.io. +Do NOT open public issues for security bugs. + +Response timeline: +- Acknowledgment: within 48 hours +- Initial assessment: within 5 days +- Fix released: within 30 days (critical), 90 days (high) + +## Security Features + +FetchML implements defense-in-depth security for ML research systems: + +### Authentication & Authorization +- **Argon2id API Key Hashing**: Memory-hard hashing resists GPU cracking +- **RBAC with Role Inheritance**: Granular permissions (admin, data_scientist, data_engineer, viewer, operator) +- **Constant-time Comparison**: Prevents timing attacks on key validation + +### Cryptographic Practices +- **Ed25519 Manifest Signing**: Tamper detection for run manifests +- **SHA-256 with Salt**: Legacy key support with migration path +- **Secure Key Generation**: 256-bit entropy for all API keys + +### Container Security +- **Rootless Podman**: No privileged containers +- **Capability Dropping**: `--cap-drop ALL` by default +- **No New Privileges**: `no-new-privileges` security opt +- **Read-only Root Filesystem**: Immutable base image + +### Input Validation +- **Path Traversal Prevention**: Canonical path validation +- **Command Injection Protection**: Shell metacharacter filtering +- **Length Limits**: Prevents DoS via oversized inputs + +### Audit & Monitoring +- **Structured Audit Logging**: JSON-formatted security events +- **Hash-chained Logs**: Tamper-evident audit trail +- **Anomaly Detection**: Brute force, privilege escalation alerts +- **Security Metrics**: Prometheus integration + +### Supply Chain +- **Dependency Scanning**: gosec + govulncheck in CI +- **No unsafe Package**: Prohibited in production code +- **Manifest Signing**: Ed25519 signatures for integrity + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| 0.2.x | :white_check_mark: | +| 0.1.x | :x: | + +## Security Checklist (Pre-Release) + +### Code Review +- [ ] No hardcoded secrets +- [ ] No `unsafe` usage without justification +- [ ] All user inputs validated +- [ ] All file paths canonicalized +- [ ] No secrets in error messages + +### Dependency Audit +- [ ] `go mod verify` passes +- [ ] `govulncheck` shows no vulnerabilities +- [ ] All dependencies pinned +- [ ] No unmaintained dependencies + +### Container Security +- [ ] No privileged containers +- [ ] Rootless execution +- [ ] Seccomp/AppArmor applied +- [ ] Network isolation + +### Cryptography +- [ ] Argon2id for key hashing +- [ ] Ed25519 for signing +- [ ] TLS 1.3 only +- [ ] No weak ciphers + +### Testing +- [ ] Security tests pass +- [ ] Fuzz tests for parsers +- [ ] Authentication bypass tested +- [ ] Container escape tested + +## Security Commands + +```bash +# Run security scan +make security-scan + +# Check for vulnerabilities +govulncheck ./... + +# Static analysis +gosec ./... + +# Check for unsafe usage +grep -r "unsafe\." --include="*.go" ./internal ./cmd + +# Build with sanitizers +cd native && cmake -DENABLE_ASAN=ON .. && make +``` + +## Threat Model + +### Attack Surfaces +1. **External API**: Researchers submitting malicious jobs +2. **Container Runtime**: Escape to host system +3. **Data Exfiltration**: Stealing datasets/models +4. **Privilege Escalation**: Researcher โ†’ admin +5. **Supply Chain**: Compromised dependencies +6. **Secrets Leakage**: API keys in logs/errors + +### Mitigations +| Threat | Mitigation | +|--------|------------| +| Malicious Jobs | Input validation, container sandboxing, resource limits | +| Container Escape | Rootless, no-new-privileges, seccomp, read-only root | +| Data Exfiltration | Network policies, audit logging, rate limiting | +| Privilege Escalation | RBAC, least privilege, anomaly detection | +| Supply Chain | Dependency scanning, manifest signing, pinned versions | +| Secrets Leakage | Log sanitization, secrets manager, memory clearing | + +## Responsible Disclosure + +We follow responsible disclosure practices: + +1. **Report privately**: Email security@fetchml.io with details +2. **Provide details**: Steps to reproduce, impact assessment +3. **Allow time**: We need 30-90 days to fix before public disclosure +4. **Acknowledgment**: We credit researchers who report valid issues + +## Security Team + +- security@fetchml.io - Security issues and questions +- security-response@fetchml.io - Active incident response + +--- + +*Last updated: 2026-02-19* + +--- + # Security Guide for Fetch ML Homelab -This guide covers security best practices for deploying Fetch ML in a homelab environment. +*The following section covers security best practices for deploying Fetch ML in a homelab environment.* ## Quick Setup diff --git a/cli/build.zig b/cli/build.zig index 201c78c..4e67890 100644 --- a/cli/build.zig +++ b/cli/build.zig @@ -76,6 +76,12 @@ pub fn build(b: *std.Build) void { exe.root_module.addOptions("build_options", options); + // Link native dataset_hash library + exe.linkLibC(); + exe.addLibraryPath(b.path("../native/build")); + exe.linkSystemLibrary("dataset_hash"); + exe.addIncludePath(b.path("../native/dataset_hash")); + // Install the executable to zig-out/bin b.installArtifact(exe); @@ -94,6 +100,13 @@ pub fn build(b: *std.Build) void { // Standard Zig test discovery - find all test files automatically const test_step = b.step("test", "Run unit tests"); + // Safety check for release builds + const safety_check_step = b.step("safety-check", "Verify ReleaseSafe mode is used for production"); + if (optimize != .ReleaseSafe and optimize != .Debug) { + const warn_no_safe = b.addSystemCommand(&.{ "echo", "WARNING: Building without ReleaseSafe mode. Production builds should use -Doptimize=ReleaseSafe" }); + safety_check_step.dependOn(&warn_no_safe.step); + } + // Test main executable const main_tests = b.addTest(.{ .root_module = b.createModule(.{ diff --git a/cli/src/commands/queue.zig b/cli/src/commands/queue.zig index 36b2557..b19a6b7 100644 --- a/cli/src/commands/queue.zig +++ b/cli/src/commands/queue.zig @@ -429,11 +429,23 @@ fn queueSingleJob( const commit_hex = try crypto.encodeHexLower(allocator, commit_id); defer allocator.free(commit_hex); - colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex }); const api_key_hash = try crypto.hashApiKey(allocator, config.api_key); defer allocator.free(api_key_hash); + // Check for existing job with same commit (incremental queue) + if (!options.force) { + const existing = try checkExistingJob(allocator, job_name, commit_id, api_key_hash, config); + if (existing) |ex| { + defer allocator.free(ex); + // Server already has this job - handle duplicate response + try handleDuplicateResponse(allocator, ex, job_name, commit_hex, options); + return; + } + } + + colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex }); + // Connect to WebSocket and send queue message const ws_url = try config.getWebSocketUrl(allocator); defer allocator.free(ws_url); @@ -1145,3 +1157,46 @@ fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions return try buf.toOwnedSlice(allocator); } + +/// Check if a job with the same commit_id already exists on the server +/// Returns: Optional JSON response from server if duplicate found +fn checkExistingJob( + allocator: std.mem.Allocator, + job_name: []const u8, + commit_id: []const u8, + api_key_hash: []const u8, + config: Config, +) !?[]const u8 { + // Connect to server and query for existing job + const ws_url = try config.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, config.api_key); + defer client.close(); + + // Send query for existing job + try client.sendQueryJobByCommit(job_name, commit_id, api_key_hash); + + const message = try client.receiveMessage(allocator); + defer allocator.free(message); + + // Parse response + const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch |err| { + // If JSON parse fails, treat as no duplicate found + std.log.debug("Failed to parse check response: {}", .{err}); + return null; + }; + defer parsed.deinit(); + + const root = parsed.value.object; + + // Check if job exists + if (root.get("exists")) |exists| { + if (!exists.bool) return null; + + // Job exists - copy the full response for caller + return try allocator.dupe(u8, message); + } + + return null; +} diff --git a/cli/src/commands/sync.zig b/cli/src/commands/sync.zig index 6e7f3d4..c3fb539 100644 --- a/cli/src/commands/sync.zig +++ b/cli/src/commands/sync.zig @@ -6,6 +6,7 @@ const rsync = @import("../utils/rsync_embedded.zig"); const ws = @import("../net/ws/client.zig"); const logging = @import("../utils/logging.zig"); const json = @import("../utils/json.zig"); +const native_hash = @import("../utils/native_hash.zig"); pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { if (args.len == 0) { @@ -26,6 +27,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { var should_queue = false; var priority: u8 = 5; var json_mode: bool = false; + var dev_mode: bool = false; + var use_timestamp_check = false; + var dry_run = false; // Parse flags var i: usize = 1; @@ -40,6 +44,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { } else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) { priority = try std.fmt.parseInt(u8, args[i + 1], 10); i += 1; + } else if (std.mem.eql(u8, args[i], "--dev")) { + dev_mode = true; + } else if (std.mem.eql(u8, args[i], "--check-timestamp")) { + use_timestamp_check = true; + } else if (std.mem.eql(u8, args[i], "--dry-run")) { + dry_run = true; } } @@ -49,16 +59,59 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void { mut_config.deinit(allocator); } - // Calculate commit ID (SHA256 of directory tree) - const commit_id = try crypto.hashDirectory(allocator, path); - defer allocator.free(commit_id); + // Detect if path is a subdirectory by finding git root + const git_root = try findGitRoot(allocator, path); + defer if (git_root) |gr| allocator.free(gr); - // Construct remote destination path - const remote_path = try std.fmt.allocPrint( - allocator, - "{s}@{s}:{s}/{s}/files/", - .{ config.api_key, config.worker_host, config.worker_base, commit_id }, - ); + const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path); + const relative_path = if (is_subdir) blk: { + // Get relative path from git root to the specified path + break :blk try std.fs.path.relative(allocator, git_root.?, path); + } else null; + defer if (relative_path) |rp| allocator.free(rp); + + // Determine commit_id and remote path based on mode + const commit_id: []const u8 = if (dev_mode) blk: { + // Dev mode: skip expensive hashing, use fixed "dev" commit + break :blk "dev"; + } else blk: { + // Production mode: calculate SHA256 of directory tree (always from git root) + const hash_base = git_root orelse path; + break :blk try crypto.hashDirectory(allocator, hash_base); + }; + defer if (!dev_mode) allocator.free(commit_id); + + // In dev mode, sync to {worker_base}/dev/files/ instead of hashed path + // For subdirectories, append the relative path to the remote destination + const remote_path = if (dev_mode) blk: { + if (is_subdir) { + break :blk try std.fmt.allocPrint( + allocator, + "{s}@{s}:{s}/dev/files/{s}/", + .{ config.api_key, config.worker_host, config.worker_base, relative_path.? }, + ); + } else { + break :blk try std.fmt.allocPrint( + allocator, + "{s}@{s}:{s}/dev/files/", + .{ config.api_key, config.worker_host, config.worker_base }, + ); + } + } else blk: { + if (is_subdir) { + break :blk try std.fmt.allocPrint( + allocator, + "{s}@{s}:{s}/{s}/files/{s}/", + .{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? }, + ); + } else { + break :blk try std.fmt.allocPrint( + allocator, + "{s}@{s}:{s}/{s}/files/", + .{ config.api_key, config.worker_host, config.worker_base, commit_id }, + ); + } + }; defer allocator.free(remote_path); // Sync using embedded rsync (no external binary needed) @@ -102,6 +155,9 @@ fn printUsage() void { logging.err(" --priority Priority to use when queueing (default: 5)\n", .{}); logging.err(" --monitor Wait and show basic sync progress\n", .{}); logging.err(" --json Output machine-readable JSON (sync result only)\n", .{}); + logging.err(" --dev Dev mode: skip hashing, use fixed path (fast)\n", .{}); + logging.err(" --check-timestamp Skip files unchanged since last sync\n", .{}); + logging.err(" --dry-run Show what would be synced without transferring\n", .{}); logging.err(" --help, -h Show this help message\n", .{}); } @@ -175,3 +231,29 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{}); } } + +/// Find the git root directory by walking up from the given path +fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 { + var buf: [std.fs.max_path_bytes]u8 = undefined; + const path = try std.fs.realpath(start_path, &buf); + + var current = path; + while (true) { + // Check if .git exists in current directory + const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" }); + defer allocator.free(git_path); + + if (std.fs.accessAbsolute(git_path, .{})) { + // Found .git directory + return try allocator.dupe(u8, current); + } else |_| { + // .git not found here, try parent + const parent = std.fs.path.dirname(current); + if (parent == null or std.mem.eql(u8, parent.?, current)) { + // Reached root without finding .git + return null; + } + current = parent.?; + } + } +} diff --git a/cli/src/net/ws/client.zig b/cli/src/net/ws/client.zig index 47668bb..7454775 100644 --- a/cli/src/net/ws/client.zig +++ b/cli/src/net/ws/client.zig @@ -220,6 +220,36 @@ pub const Client = struct { try builder.send(stream); } + pub fn sendQueryJobByCommit(self: *Client, job_name: []const u8, commit_id: []const u8, api_key_hash: []const u8) !void { + const stream = try self.getStream(); + try validateApiKeyHash(api_key_hash); + try validateCommitId(commit_id); + try validateJobName(job_name); + + // Build binary message: + // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] [commit_id: 20 bytes] + const total_len = 1 + 16 + 1 + job_name.len + 20; + var buffer = try self.allocator.alloc(u8, total_len); + defer self.allocator.free(buffer); + + var offset: usize = 0; + buffer[offset] = @intFromEnum(opcode.query_job); + offset += 1; + + @memcpy(buffer[offset .. offset + 16], api_key_hash); + offset += 16; + + buffer[offset] = @intCast(job_name.len); + offset += 1; + + @memcpy(buffer[offset .. offset + job_name.len], job_name); + offset += job_name.len; + + @memcpy(buffer[offset .. offset + 20], commit_id); + + try frame.sendWebSocketFrame(stream, buffer); + } + pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void { const stream = try self.getStream(); try validateApiKeyHash(api_key_hash); diff --git a/cli/src/net/ws/opcode.zig b/cli/src/net/ws/opcode.zig index c688faf..4cf10f1 100644 --- a/cli/src/net/ws/opcode.zig +++ b/cli/src/net/ws/opcode.zig @@ -22,6 +22,9 @@ pub const Opcode = enum(u8) { validate_request = 0x16, + // Job query opcode + query_job = 0x23, + // Logs and debug opcodes get_logs = 0x20, stream_logs = 0x21, @@ -68,6 +71,7 @@ pub const restore_jupyter = Opcode.restore_jupyter; pub const list_jupyter = Opcode.list_jupyter; pub const list_jupyter_packages = Opcode.list_jupyter_packages; pub const validate_request = Opcode.validate_request; +pub const query_job = Opcode.query_job; pub const get_logs = Opcode.get_logs; pub const stream_logs = Opcode.stream_logs; pub const attach_debug = Opcode.attach_debug; diff --git a/cli/src/utils/crypto.zig b/cli/src/utils/crypto.zig index cc1ed59..da14653 100644 --- a/cli/src/utils/crypto.zig +++ b/cli/src/utils/crypto.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const ignore = @import("ignore.zig"); pub fn encodeHexLower(allocator: std.mem.Allocator, bytes: []const u8) ![]u8 { const hex = try allocator.alloc(u8, bytes.len * 2); @@ -105,13 +106,20 @@ pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths: return encodeHexLower(allocator, &hash); } -/// Calculate commit ID for a directory (SHA256 of tree state) +/// Calculate commit ID for a directory (SHA256 of tree state, respecting .gitignore) pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 { var hasher = std.crypto.hash.sha2.Sha256.init(.{}); var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true }); defer dir.close(); + // Load .gitignore and .mlignore patterns + var gitignore = ignore.GitIgnore.init(allocator); + defer gitignore.deinit(); + + try gitignore.loadFromDir(dir_path, ".gitignore"); + try gitignore.loadFromDir(dir_path, ".mlignore"); + var walker = try dir.walk(allocator); defer walker.deinit(allocator); @@ -124,6 +132,12 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 { while (try walker.next()) |entry| { if (entry.kind == .file) { + // Skip files matching default ignores + if (ignore.matchesDefaultIgnore(entry.path)) continue; + + // Skip files matching .gitignore/.mlignore patterns + if (gitignore.isIgnored(entry.path, false)) continue; + try paths.append(allocator, try allocator.dupe(u8, entry.path)); } } diff --git a/cli/src/utils/hash_cache.zig b/cli/src/utils/hash_cache.zig new file mode 100644 index 0000000..9148256 --- /dev/null +++ b/cli/src/utils/hash_cache.zig @@ -0,0 +1,333 @@ +const std = @import("std"); +const crypto = @import("crypto.zig"); +const json = @import("json.zig"); + +/// Cache entry for a single file +const CacheEntry = struct { + mtime: i64, + hash: []const u8, + + pub fn deinit(self: *const CacheEntry, allocator: std.mem.Allocator) void { + allocator.free(self.hash); + } +}; + +/// Hash cache that stores file mtimes and hashes to avoid re-hashing unchanged files +pub const HashCache = struct { + entries: std.StringHashMap(CacheEntry), + allocator: std.mem.Allocator, + cache_path: []const u8, + dirty: bool, + + pub fn init(allocator: std.mem.Allocator) HashCache { + return .{ + .entries = std.StringHashMap(CacheEntry).init(allocator), + .allocator = allocator, + .cache_path = "", + .dirty = false, + }; + } + + pub fn deinit(self: *HashCache) void { + var it = self.entries.iterator(); + while (it.next()) |entry| { + entry.value_ptr.deinit(self.allocator); + self.allocator.free(entry.key_ptr.*); + } + self.entries.deinit(); + if (self.cache_path.len > 0) { + self.allocator.free(self.cache_path); + } + } + + /// Get default cache path: ~/.ml/cache/hashes.json + pub fn getDefaultPath(allocator: std.mem.Allocator) ![]const u8 { + const home = std.posix.getenv("HOME") orelse { + return error.NoHomeDirectory; + }; + + // Ensure cache directory exists + const cache_dir = try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache" }); + defer allocator.free(cache_dir); + + std.fs.cwd().makeDir(cache_dir) catch |err| switch (err) { + error.PathAlreadyExists => {}, + else => return err, + }; + + return try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache", "hashes.json" }); + } + + /// Load cache from disk + pub fn load(self: *HashCache) !void { + const cache_path = try getDefaultPath(self.allocator); + self.cache_path = cache_path; + + const file = std.fs.cwd().openFile(cache_path, .{}) catch |err| switch (err) { + error.FileNotFound => return, // No cache yet is fine + else => return err, + }; + defer file.close(); + + const content = try file.readToEndAlloc(self.allocator, 10 * 1024 * 1024); // Max 10MB + defer self.allocator.free(content); + + // Parse JSON + const parsed = try std.json.parseFromSlice(std.json.Value, self.allocator, content, .{}); + defer parsed.deinit(); + + const root = parsed.value.object; + const version = root.get("version") orelse return error.InvalidCacheFormat; + if (version.integer != 1) return error.UnsupportedCacheVersion; + + const files = root.get("files") orelse return error.InvalidCacheFormat; + if (files.object.count() == 0) return; + + var it = files.object.iterator(); + while (it.next()) |entry| { + const path = try self.allocator.dupe(u8, entry.key_ptr.*); + + const file_obj = entry.value_ptr.object; + const mtime = file_obj.get("mtime") orelse continue; + const hash_val = file_obj.get("hash") orelse continue; + + const hash = try self.allocator.dupe(u8, hash_val.string); + + try self.entries.put(path, .{ + .mtime = mtime.integer, + .hash = hash, + }); + } + } + + /// Save cache to disk + pub fn save(self: *HashCache) !void { + if (!self.dirty) return; + + var json_str = std.ArrayList(u8).init(self.allocator); + defer json_str.deinit(); + + var writer = json_str.writer(); + + // Write header + try writer.print("{{\n \"version\": 1,\n \"files\": {{\n", .{}); + + // Write entries + var it = self.entries.iterator(); + var first = true; + while (it.next()) |entry| { + if (!first) try writer.print(",\n", .{}); + first = false; + + // Escape path for JSON + const escaped_path = try json.escapeString(self.allocator, entry.key_ptr.*); + defer self.allocator.free(escaped_path); + + try writer.print(" \"{s}\": {{\"mtime\": {d}, \"hash\": \"{s}\"}}", .{ + escaped_path, + entry.value_ptr.mtime, + entry.value_ptr.hash, + }); + } + + // Write footer + try writer.print("\n }}\n}}\n", .{}); + + // Write atomically + const tmp_path = try std.fmt.allocPrint(self.allocator, "{s}.tmp", .{self.cache_path}); + defer self.allocator.free(tmp_path); + + { + const file = try std.fs.cwd().createFile(tmp_path, .{}); + defer file.close(); + try file.writeAll(json_str.items); + } + + try std.fs.cwd().rename(tmp_path, self.cache_path); + self.dirty = false; + } + + /// Check if file needs re-hashing + pub fn needsHash(self: *HashCache, path: []const u8, mtime: i64) bool { + const entry = self.entries.get(path) orelse return true; + return entry.mtime != mtime; + } + + /// Get cached hash for file + pub fn getHash(self: *HashCache, path: []const u8, mtime: i64) ?[]const u8 { + const entry = self.entries.get(path) orelse return null; + if (entry.mtime != mtime) return null; + return entry.hash; + } + + /// Store hash for file + pub fn putHash(self: *HashCache, path: []const u8, mtime: i64, hash: []const u8) !void { + const path_copy = try self.allocator.dupe(u8, path); + + // Remove old entry if exists + if (self.entries.fetchRemove(path_copy)) |old| { + self.allocator.free(old.key); + old.value.deinit(self.allocator); + } + + const hash_copy = try self.allocator.dupe(u8, hash); + + try self.entries.put(path_copy, .{ + .mtime = mtime, + .hash = hash_copy, + }); + + self.dirty = true; + } + + /// Clear cache (e.g., after git checkout) + pub fn clear(self: *HashCache) void { + var it = self.entries.iterator(); + while (it.next()) |entry| { + entry.value_ptr.deinit(self.allocator); + self.allocator.free(entry.key_ptr.*); + } + self.entries.clearRetainingCapacity(); + self.dirty = true; + } + + /// Get cache stats + pub fn getStats(self: *HashCache) struct { entries: usize, dirty: bool } { + return .{ + .entries = self.entries.count(), + .dirty = self.dirty, + }; + } +}; + +/// Calculate directory hash with cache support +pub fn hashDirectoryWithCache( + allocator: std.mem.Allocator, + dir_path: []const u8, + cache: *HashCache, +) ![]const u8 { + var hasher = std.crypto.hash.sha2.Sha256.init(.{}); + + var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true }); + defer dir.close(); + + // Load .gitignore patterns + var gitignore = @import("ignore.zig").GitIgnore.init(allocator); + defer gitignore.deinit(); + + try gitignore.loadFromDir(dir_path, ".gitignore"); + try gitignore.loadFromDir(dir_path, ".mlignore"); + + var walker = try dir.walk(allocator); + defer walker.deinit(allocator); + + // Collect paths and check cache + var paths: std.ArrayList(struct { path: []const u8, mtime: i64, use_cache: bool }) = .{}; + defer { + for (paths.items) |p| allocator.free(p.path); + paths.deinit(allocator); + } + + while (try walker.next()) |entry| { + if (entry.kind == .file) { + // Skip files matching default ignores + if (@import("ignore.zig").matchesDefaultIgnore(entry.path)) continue; + + // Skip files matching .gitignore/.mlignore patterns + if (gitignore.isIgnored(entry.path, false)) continue; + + const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, entry.path }); + defer allocator.free(full_path); + + const stat = dir.statFile(entry.path) catch |err| switch (err) { + error.FileNotFound => continue, + else => return err, + }; + + const mtime = @as(i64, @intCast(stat.mtime)); + const use_cache = !cache.needsHash(entry.path, mtime); + + try paths.append(.{ + .path = try allocator.dupe(u8, entry.path), + .mtime = mtime, + .use_cache = use_cache, + }); + } + } + + // Sort paths for deterministic hashing + std.sort.block( + struct { path: []const u8, mtime: i64, use_cache: bool }, + paths.items, + {}, + struct { + fn lessThan(_: void, a: anytype, b: anytype) bool { + return std.mem.order(u8, a.path, b.path) == .lt; + } + }.lessThan, + ); + + // Hash each file (using cache where possible) + for (paths.items) |item| { + hasher.update(item.path); + hasher.update(&[_]u8{0}); // Separator + + const file_hash: []const u8 = if (item.use_cache) + cache.getHash(item.path, item.mtime).? + else blk: { + const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, item.path }); + defer allocator.free(full_path); + + const hash = try crypto.hashFile(allocator, full_path); + try cache.putHash(item.path, item.mtime, hash); + break :blk hash; + }; + defer if (!item.use_cache) allocator.free(file_hash); + + hasher.update(file_hash); + hasher.update(&[_]u8{0}); // Separator + } + + var hash: [32]u8 = undefined; + hasher.final(&hash); + return crypto.encodeHexLower(allocator, &hash); +} + +test "HashCache basic operations" { + const allocator = std.testing.allocator; + + var cache = HashCache.init(allocator); + defer cache.deinit(); + + // Put and get + try cache.putHash("src/main.py", 1708369200, "abc123"); + + const hash = cache.getHash("src/main.py", 1708369200); + try std.testing.expect(hash != null); + try std.testing.expectEqualStrings("abc123", hash.?); + + // Wrong mtime should return null + const stale = cache.getHash("src/main.py", 1708369201); + try std.testing.expect(stale == null); + + // needsHash should detect stale entries + try std.testing.expect(cache.needsHash("src/main.py", 1708369201)); + try std.testing.expect(!cache.needsHash("src/main.py", 1708369200)); +} + +test "HashCache clear" { + const allocator = std.testing.allocator; + + var cache = HashCache.init(allocator); + defer cache.deinit(); + + try cache.putHash("file1.py", 123, "hash1"); + try cache.putHash("file2.py", 456, "hash2"); + + try std.testing.expectEqual(@as(usize, 2), cache.getStats().entries); + + cache.clear(); + + try std.testing.expectEqual(@as(usize, 0), cache.getStats().entries); + try std.testing.expect(cache.getStats().dirty); +} diff --git a/cli/src/utils/ignore.zig b/cli/src/utils/ignore.zig new file mode 100644 index 0000000..79acc1b --- /dev/null +++ b/cli/src/utils/ignore.zig @@ -0,0 +1,261 @@ +const std = @import("std"); + +/// Pattern type for ignore rules +const Pattern = struct { + pattern: []const u8, + is_negation: bool, // true if pattern starts with ! + is_dir_only: bool, // true if pattern ends with / + anchored: bool, // true if pattern contains / (not at start) +}; + +/// GitIgnore matcher for filtering files during directory traversal +pub const GitIgnore = struct { + patterns: std.ArrayList(Pattern), + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) GitIgnore { + return .{ + .patterns = std.ArrayList(Pattern).init(allocator), + .allocator = allocator, + }; + } + + pub fn deinit(self: *GitIgnore) void { + for (self.patterns.items) |p| { + self.allocator.free(p.pattern); + } + self.patterns.deinit(); + } + + /// Load .gitignore or .mlignore from directory + pub fn loadFromDir(self: *GitIgnore, dir_path: []const u8, filename: []const u8) !void { + const path = try std.fs.path.join(self.allocator, &[_][]const u8{ dir_path, filename }); + defer self.allocator.free(path); + + const file = std.fs.cwd().openFile(path, .{}) catch |err| switch (err) { + error.FileNotFound => return, // No ignore file is fine + else => return err, + }; + defer file.close(); + + const content = try file.readToEndAlloc(self.allocator, 1024 * 1024); // Max 1MB + defer self.allocator.free(content); + + try self.parse(content); + } + + /// Parse ignore patterns from content + pub fn parse(self: *GitIgnore, content: []const u8) !void { + var lines = std.mem.split(u8, content, "\n"); + while (lines.next()) |line| { + const trimmed = std.mem.trim(u8, line, " \t\r"); + + // Skip empty lines and comments + if (trimmed.len == 0 or std.mem.startsWith(u8, trimmed, "#")) continue; + + try self.addPattern(trimmed); + } + } + + /// Add a single pattern + fn addPattern(self: *GitIgnore, pattern: []const u8) !void { + var p = pattern; + var is_negation = false; + var is_dir_only = false; + + // Check for negation + if (std.mem.startsWith(u8, p, "!")) { + is_negation = true; + p = p[1..]; + } + + // Check for directory-only marker + if (std.mem.endsWith(u8, p, "/")) { + is_dir_only = true; + p = p[0 .. p.len - 1]; + } + + // Remove leading slash (anchored patterns) + const anchored = std.mem.indexOf(u8, p, "/") != null; + if (std.mem.startsWith(u8, p, "/")) { + p = p[1..]; + } + + // Store normalized pattern + const pattern_copy = try self.allocator.dupe(u8, p); + try self.patterns.append(.{ + .pattern = pattern_copy, + .is_negation = is_negation, + .is_dir_only = is_dir_only, + .anchored = anchored, + }); + } + + /// Check if a path should be ignored + pub fn isIgnored(self: *GitIgnore, path: []const u8, is_dir: bool) bool { + var ignored = false; + + for (self.patterns.items) |pattern| { + if (self.matches(pattern, path, is_dir)) { + ignored = !pattern.is_negation; + } + } + + return ignored; + } + + /// Check if a single pattern matches + fn matches(self: *GitIgnore, pattern: Pattern, path: []const u8, is_dir: bool) bool { + _ = self; + + // Directory-only patterns only match directories + if (pattern.is_dir_only and !is_dir) return false; + + // Convert gitignore pattern to glob + if (patternMatch(pattern.pattern, path)) { + return true; + } + + // Also check basename match for non-anchored patterns + if (!pattern.anchored) { + if (std.mem.lastIndexOf(u8, path, "/")) |idx| { + const basename = path[idx + 1 ..]; + if (patternMatch(pattern.pattern, basename)) { + return true; + } + } + } + + return false; + } + + /// Simple glob pattern matching + fn patternMatch(pattern: []const u8, path: []const u8) bool { + var p_idx: usize = 0; + var s_idx: usize = 0; + + while (p_idx < pattern.len) { + const p_char = pattern[p_idx]; + + if (p_char == '*') { + // Handle ** (matches any number of directories) + if (p_idx + 1 < pattern.len and pattern[p_idx + 1] == '*') { + // ** matches everything + return true; + } + + // Single * matches anything until next / or end + p_idx += 1; + if (p_idx >= pattern.len) { + // * at end - match rest of path + return true; + } + + const next_char = pattern[p_idx]; + while (s_idx < path.len and path[s_idx] != next_char) { + s_idx += 1; + } + } else if (p_char == '?') { + // ? matches single character + if (s_idx >= path.len) return false; + p_idx += 1; + s_idx += 1; + } else { + // Literal character match + if (s_idx >= path.len or path[s_idx] != p_char) return false; + p_idx += 1; + s_idx += 1; + } + } + + return s_idx == path.len; + } +}; + +/// Default patterns always ignored (like git does) +pub const DEFAULT_IGNORES = [_][]const u8{ + ".git", + ".ml", + "__pycache__", + "*.pyc", + "*.pyo", + ".DS_Store", + "node_modules", + ".venv", + "venv", + ".env", + ".idea", + ".vscode", + "*.log", + "*.tmp", + "*.swp", + "*.swo", + "*~", +}; + +/// Check if path matches default ignores +pub fn matchesDefaultIgnore(path: []const u8) bool { + // Check exact matches + for (DEFAULT_IGNORES) |pattern| { + if (std.mem.eql(u8, path, pattern)) return true; + } + + // Check suffix matches for patterns like *.pyc + if (std.mem.lastIndexOf(u8, path, "/")) |idx| { + const basename = path[idx + 1 ..]; + for (DEFAULT_IGNORES) |pattern| { + if (std.mem.startsWith(u8, pattern, "*.")) { + const ext = pattern[1..]; // Get extension including dot + if (std.mem.endsWith(u8, basename, ext)) return true; + } + } + } + + return false; +} + +test "GitIgnore basic patterns" { + const allocator = std.testing.allocator; + + var gi = GitIgnore.init(allocator); + defer gi.deinit(); + + try gi.parse("node_modules\n__pycache__\n*.pyc\n"); + + try std.testing.expect(gi.isIgnored("node_modules", true)); + try std.testing.expect(gi.isIgnored("__pycache__", true)); + try std.testing.expect(gi.isIgnored("test.pyc", false)); + try std.testing.expect(!gi.isIgnored("main.py", false)); +} + +test "GitIgnore negation" { + const allocator = std.testing.allocator; + + var gi = GitIgnore.init(allocator); + defer gi.deinit(); + + try gi.parse("*.log\n!important.log\n"); + + try std.testing.expect(gi.isIgnored("debug.log", false)); + try std.testing.expect(!gi.isIgnored("important.log", false)); +} + +test "GitIgnore directory-only" { + const allocator = std.testing.allocator; + + var gi = GitIgnore.init(allocator); + defer gi.deinit(); + + try gi.parse("build/\n"); + + try std.testing.expect(gi.isIgnored("build", true)); + try std.testing.expect(!gi.isIgnored("build", false)); +} + +test "matchesDefaultIgnore" { + try std.testing.expect(matchesDefaultIgnore(".git")); + try std.testing.expect(matchesDefaultIgnore("__pycache__")); + try std.testing.expect(matchesDefaultIgnore("node_modules")); + try std.testing.expect(matchesDefaultIgnore("test.pyc")); + try std.testing.expect(!matchesDefaultIgnore("main.py")); +} diff --git a/cli/src/utils/native_hash.zig b/cli/src/utils/native_hash.zig new file mode 100644 index 0000000..6bb5c8d --- /dev/null +++ b/cli/src/utils/native_hash.zig @@ -0,0 +1,195 @@ +const std = @import("std"); +const c = @cImport({ + @cInclude("dataset_hash.h"); +}); + +/// Native hash context for high-performance file hashing +pub const NativeHasher = struct { + ctx: *c.fh_context_t, + allocator: std.mem.Allocator, + + /// Initialize native hasher with thread pool + /// num_threads: 0 = auto-detect (use hardware concurrency) + pub fn init(allocator: std.mem.Allocator, num_threads: u32) !NativeHasher { + const ctx = c.fh_init(num_threads); + if (ctx == null) return error.NativeInitFailed; + + return .{ + .ctx = ctx, + .allocator = allocator, + }; + } + + /// Cleanup native hasher and thread pool + pub fn deinit(self: *NativeHasher) void { + c.fh_cleanup(self.ctx); + } + + /// Hash a single file + pub fn hashFile(self: *NativeHasher, path: []const u8) ![]const u8 { + const c_path = try self.allocator.dupeZ(u8, path); + defer self.allocator.free(c_path); + + const result = c.fh_hash_file(self.ctx, c_path.ptr); + if (result == null) return error.HashFailed; + defer c.fh_free_string(result); + + return try self.allocator.dupe(u8, std.mem.span(result)); + } + + /// Batch hash multiple files (amortizes CGo overhead) + pub fn hashBatch(self: *NativeHasher, paths: []const []const u8) ![][]const u8 { + // Convert paths to C string array + const c_paths = try self.allocator.alloc([*c]const u8, paths.len); + defer self.allocator.free(c_paths); + + for (paths, 0..) |path, i| { + const c_path = try self.allocator.dupeZ(u8, path); + c_paths[i] = c_path.ptr; + // Note: we need to keep these alive until after fh_hash_batch + } + defer { + for (c_paths) |p| { + self.allocator.free(std.mem.span(p)); + } + } + + // Allocate results array + const results = try self.allocator.alloc([*c]u8, paths.len); + defer self.allocator.free(results); + + // Call native batch hash + const ret = c.fh_hash_batch(self.ctx, c_paths.ptr, @intCast(paths.len), results.ptr); + if (ret != 0) return error.HashFailed; + + // Convert results to Zig strings + var hashes = try self.allocator.alloc([]const u8, paths.len); + errdefer { + for (hashes) |h| self.allocator.free(h); + self.allocator.free(hashes); + } + + for (results, 0..) |r, i| { + hashes[i] = try self.allocator.dupe(u8, std.mem.span(r)); + c.fh_free_string(r); + } + + return hashes; + } + + /// Hash entire directory (combined hash) + pub fn hashDirectory(self: *NativeHasher, dir_path: []const u8) ![]const u8 { + const c_path = try self.allocator.dupeZ(u8, dir_path); + defer self.allocator.free(c_path); + + const result = c.fh_hash_directory(self.ctx, c_path.ptr); + if (result == null) return error.HashFailed; + defer c.fh_free_string(result); + + return try self.allocator.dupe(u8, std.mem.span(result)); + } + + /// Hash directory with batch output (individual file hashes) + pub fn hashDirectoryBatch( + self: *NativeHasher, + dir_path: []const u8, + max_results: u32, + ) !struct { hashes: [][]const u8, paths: [][]const u8, count: u32 } { + const c_path = try self.allocator.dupeZ(u8, dir_path); + defer self.allocator.free(c_path); + + // Allocate output arrays + const hashes = try self.allocator.alloc([*c]u8, max_results); + defer self.allocator.free(hashes); + + const paths = try self.allocator.alloc([*c]u8, max_results); + defer self.allocator.free(paths); + + var count: u32 = 0; + + const ret = c.fh_hash_directory_batch( + self.ctx, + c_path.ptr, + hashes.ptr, + paths.ptr, + max_results, + &count, + ); + if (ret != 0) return error.HashFailed; + + // Convert to Zig arrays + var zig_hashes = try self.allocator.alloc([]const u8, count); + errdefer { + for (zig_hashes) |h| self.allocator.free(h); + self.allocator.free(zig_hashes); + } + + var zig_paths = try self.allocator.alloc([]const u8, count); + errdefer { + for (zig_paths) |p| self.allocator.free(p); + self.allocator.free(zig_paths); + } + + for (0..count) |i| { + zig_hashes[i] = try self.allocator.dupe(u8, std.mem.span(hashes[i])); + c.fh_free_string(hashes[i]); + + zig_paths[i] = try self.allocator.dupe(u8, std.mem.span(paths[i])); + c.fh_free_string(paths[i]); + } + + return .{ + .hashes = zig_hashes, + .paths = zig_paths, + .count = count, + }; + } + + /// Check if SIMD SHA-256 is available + pub fn hasSimd(self: *NativeHasher) bool { + _ = self; + return c.fh_has_simd_sha256() != 0; + } + + /// Get implementation info (SIMD type, etc.) + pub fn getImplInfo(self: *NativeHasher) []const u8 { + _ = self; + return std.mem.span(c.fh_get_simd_impl_name()); + } +}; + +/// Convenience function: hash directory using native library +pub fn hashDirectoryNative(allocator: std.mem.Allocator, dir_path: []const u8) ![]const u8 { + var hasher = try NativeHasher.init(allocator, 0); // Auto-detect threads + defer hasher.deinit(); + return try hasher.hashDirectory(dir_path); +} + +/// Convenience function: batch hash files using native library +pub fn hashFilesNative( + allocator: std.mem.Allocator, + paths: []const []const u8, +) ![][]const u8 { + var hasher = try NativeHasher.init(allocator, 0); + defer hasher.deinit(); + return try hasher.hashBatch(paths); +} + +test "NativeHasher basic operations" { + const allocator = std.testing.allocator; + + // Skip if native library not available + var hasher = NativeHasher.init(allocator, 1) catch |err| { + if (err == error.NativeInitFailed) { + std.debug.print("Native library not available, skipping test\n", .{}); + return; + } + return err; + }; + defer hasher.deinit(); + + // Check SIMD availability + const has_simd = hasher.hasSimd(); + const impl_name = hasher.getImplInfo(); + std.debug.print("SIMD: {any}, Impl: {s}\n", .{ has_simd, impl_name }); +} diff --git a/cli/src/utils/parallel_walk.zig b/cli/src/utils/parallel_walk.zig new file mode 100644 index 0000000..68cc964 --- /dev/null +++ b/cli/src/utils/parallel_walk.zig @@ -0,0 +1,231 @@ +const std = @import("std"); +const ignore = @import("ignore.zig"); + +/// Thread-safe work queue for parallel directory walking +const WorkQueue = struct { + items: std.ArrayList(WorkItem), + mutex: std.Thread.Mutex, + condition: std.Thread.Condition, + done: bool, + + const WorkItem = struct { + path: []const u8, + depth: usize, + }; + + fn init(allocator: std.mem.Allocator) WorkQueue { + return .{ + .items = std.ArrayList(WorkItem).init(allocator), + .mutex = .{}, + .condition = .{}, + .done = false, + }; + } + + fn deinit(self: *WorkQueue, allocator: std.mem.Allocator) void { + for (self.items.items) |item| { + allocator.free(item.path); + } + self.items.deinit(); + } + + fn push(self: *WorkQueue, path: []const u8, depth: usize, allocator: std.mem.Allocator) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + try self.items.append(.{ + .path = try allocator.dupe(u8, path), + .depth = depth, + }); + self.condition.signal(); + } + + fn pop(self: *WorkQueue) ?WorkItem { + self.mutex.lock(); + defer self.mutex.unlock(); + + while (self.items.items.len == 0 and !self.done) { + self.condition.wait(&self.mutex); + } + + if (self.items.items.len == 0) return null; + return self.items.pop(); + } + + fn setDone(self: *WorkQueue) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.done = true; + self.condition.broadcast(); + } +}; + +/// Result from parallel directory walk +const WalkResult = struct { + files: std.ArrayList([]const u8), + mutex: std.Thread.Mutex, + + fn init(allocator: std.mem.Allocator) WalkResult { + return .{ + .files = std.ArrayList([]const u8).init(allocator), + .mutex = .{}, + }; + } + + fn deinit(self: *WalkResult, allocator: std.mem.Allocator) void { + for (self.files.items) |file| { + allocator.free(file); + } + self.files.deinit(); + } + + fn add(self: *WalkResult, path: []const u8, allocator: std.mem.Allocator) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + try self.files.append(try allocator.dupe(u8, path)); + } +}; + +/// Thread context for parallel walking +const ThreadContext = struct { + queue: *WorkQueue, + result: *WalkResult, + gitignore: *ignore.GitIgnore, + base_path: []const u8, + allocator: std.mem.Allocator, + max_depth: usize, +}; + +/// Worker thread function for parallel directory walking +fn walkWorker(ctx: *ThreadContext) void { + while (true) { + const item = ctx.queue.pop() orelse break; + defer ctx.allocator.free(item.path); + + if (item.depth >= ctx.max_depth) continue; + + walkDirectoryParallel(ctx, item.path, item.depth) catch |err| { + std.log.warn("Error walking {s}: {any}", .{ item.path, err }); + }; + } +} + +/// Walk a single directory and add subdirectories to queue +fn walkDirectoryParallel(ctx: *ThreadContext, dir_path: []const u8, depth: usize) !void { + const full_path = if (std.mem.eql(u8, dir_path, ".")) + ctx.base_path + else + try std.fs.path.join(ctx.allocator, &[_][]const u8{ ctx.base_path, dir_path }); + defer if (!std.mem.eql(u8, dir_path, ".")) ctx.allocator.free(full_path); + + var dir = std.fs.cwd().openDir(full_path, .{ .iterate = true }) catch |err| switch (err) { + error.AccessDenied => return, + error.FileNotFound => return, + else => return err, + }; + defer dir.close(); + + var it = dir.iterate(); + while (true) { + const entry = it.next() catch |err| switch (err) { + error.AccessDenied => continue, + else => return err, + } orelse break; + + const entry_path = if (std.mem.eql(u8, dir_path, ".")) + try std.fmt.allocPrint(ctx.allocator, "{s}", .{entry.name}) + else + try std.fs.path.join(ctx.allocator, &[_][]const u8{ dir_path, entry.name }); + defer ctx.allocator.free(entry_path); + + // Check default ignores + if (ignore.matchesDefaultIgnore(entry_path)) continue; + + // Check gitignore patterns + const is_dir = entry.kind == .directory; + if (ctx.gitignore.isIgnored(entry_path, is_dir)) continue; + + if (is_dir) { + // Add subdirectory to work queue + ctx.queue.push(entry_path, depth + 1, ctx.allocator) catch |err| { + std.log.warn("Failed to queue {s}: {any}", .{ entry_path, err }); + }; + } else if (entry.kind == .file) { + // Add file to results + ctx.result.add(entry_path, ctx.allocator) catch |err| { + std.log.warn("Failed to add file {s}: {any}", .{ entry_path, err }); + }; + } + } +} + +/// Parallel directory walker that uses multiple threads +pub fn parallelWalk( + allocator: std.mem.Allocator, + base_path: []const u8, + gitignore: *ignore.GitIgnore, + num_threads: usize, +) ![][]const u8 { + var queue = WorkQueue.init(allocator); + defer queue.deinit(allocator); + + var result = WalkResult.init(allocator); + defer result.deinit(allocator); + + // Start with base directory + try queue.push(".", 0, allocator); + + // Create thread context + var ctx = ThreadContext{ + .queue = &queue, + .result = &result, + .gitignore = gitignore, + .base_path = base_path, + .allocator = allocator, + .max_depth = 100, // Prevent infinite recursion + }; + + // Spawn worker threads + var threads = try allocator.alloc(std.Thread, num_threads); + defer allocator.free(threads); + + for (0..num_threads) |i| { + threads[i] = try std.Thread.spawn(.{}, walkWorker, .{&ctx}); + } + + // Wait for all workers to complete + for (threads) |thread| { + thread.join(); + } + + // Sort results for deterministic ordering + std.sort.block([]const u8, result.files.items, {}, struct { + fn lessThan(_: void, a: []const u8, b: []const u8) bool { + return std.mem.order(u8, a, b) == .lt; + } + }.lessThan); + + // Transfer ownership to caller + const files = try allocator.alloc([]const u8, result.files.items.len); + @memcpy(files, result.files.items); + result.files.items.len = 0; // Prevent deinit from freeing + + return files; +} + +test "parallelWalk basic" { + const allocator = std.testing.allocator; + + var gitignore = ignore.GitIgnore.init(allocator); + defer gitignore.deinit(); + + // Walk the current directory with 4 threads + const files = try parallelWalk(allocator, ".", &gitignore, 4); + defer { + for (files) |f| allocator.free(f); + allocator.free(files); + } + + // Should find at least some files + try std.testing.expect(files.len > 0); +} diff --git a/cli/src/utils/watch.zig b/cli/src/utils/watch.zig new file mode 100644 index 0000000..2c7aa92 --- /dev/null +++ b/cli/src/utils/watch.zig @@ -0,0 +1,283 @@ +const std = @import("std"); +const os = std.os; +const log = std.log; + +/// File watcher using OS-native APIs (kqueue on macOS, inotify on Linux) +/// Zero third-party dependencies - uses only standard library OS bindings +pub const FileWatcher = struct { + allocator: std.mem.Allocator, + watched_paths: std.StringHashMap(void), + + // Platform-specific handles + kqueue_fd: if (@import("builtin").target.os.tag == .macos) i32 else void, + inotify_fd: if (@import("builtin").target.os.tag == .linux) i32 else void, + + // Debounce timer + last_event_time: i64, + debounce_ms: i64, + + pub fn init(allocator: std.mem.Allocator, debounce_ms: i64) !FileWatcher { + const target = @import("builtin").target; + + var watcher = FileWatcher{ + .allocator = allocator, + .watched_paths = std.StringHashMap(void).init(allocator), + .kqueue_fd = if (target.os.tag == .macos) -1 else {}, + .inotify_fd = if (target.os.tag == .linux) -1 else {}, + .last_event_time = 0, + .debounce_ms = debounce_ms, + }; + + switch (target.os.tag) { + .macos => { + watcher.kqueue_fd = try os.kqueue(); + }, + .linux => { + watcher.inotify_fd = try std.os.inotify_init1(os.linux.IN_CLOEXEC); + }, + else => { + return error.UnsupportedPlatform; + }, + } + + return watcher; + } + + pub fn deinit(self: *FileWatcher) void { + const target = @import("builtin").target; + + switch (target.os.tag) { + .macos => { + if (self.kqueue_fd != -1) { + os.close(self.kqueue_fd); + } + }, + .linux => { + if (self.inotify_fd != -1) { + os.close(self.inotify_fd); + } + }, + else => {}, + } + + var it = self.watched_paths.keyIterator(); + while (it.next()) |key| { + self.allocator.free(key.*); + } + self.watched_paths.deinit(); + } + + /// Add a directory to watch recursively + pub fn watchDirectory(self: *FileWatcher, path: []const u8) !void { + if (self.watched_paths.contains(path)) return; + + const path_copy = try self.allocator.dupe(u8, path); + try self.watched_paths.put(path_copy, {}); + + const target = @import("builtin").target; + + switch (target.os.tag) { + .macos => try self.addKqueueWatch(path), + .linux => try self.addInotifyWatch(path), + else => return error.UnsupportedPlatform, + } + + // Recursively watch subdirectories + var dir = std.fs.cwd().openDir(path, .{ .iterate = true }) catch |err| switch (err) { + error.AccessDenied => return, + error.FileNotFound => return, + else => return err, + }; + defer dir.close(); + + var it = dir.iterate(); + while (true) { + const entry = it.next() catch |err| switch (err) { + error.AccessDenied => continue, + else => return err, + } orelse break; + + if (entry.kind == .directory) { + const subpath = try std.fs.path.join(self.allocator, &[_][]const u8{ path, entry.name }); + defer self.allocator.free(subpath); + try self.watchDirectory(subpath); + } + } + } + + /// Add kqueue watch for macOS + fn addKqueueWatch(self: *FileWatcher, path: []const u8) !void { + if (@import("builtin").target.os.tag != .macos) return; + + const fd = try os.open(path, os.O.EVTONLY | os.O_RDONLY, 0); + defer os.close(fd); + + const event = os.Kevent{ + .ident = @intCast(fd), + .filter = os.EVFILT_VNODE, + .flags = os.EV_ADD | os.EV_CLEAR, + .fflags = os.NOTE_WRITE | os.NOTE_EXTEND | os.NOTE_ATTRIB | os.NOTE_LINK | os.NOTE_RENAME | os.NOTE_REVOKE, + .data = 0, + .udata = 0, + }; + + const changes = [_]os.Kevent{event}; + _ = try os.kevent(self.kqueue_fd, &changes, &.{}, null); + } + + /// Add inotify watch for Linux + fn addInotifyWatch(self: *FileWatcher, path: []const u8) !void { + if (@import("builtin").target.os.tag != .linux) return; + + const mask = os.linux.IN_MODIFY | os.linux.IN_CREATE | os.linux.IN_DELETE | + os.linux.IN_MOVED_FROM | os.linux.IN_MOVED_TO | os.linux.IN_ATTRIB; + + const wd = try os.linux.inotify_add_watch(self.inotify_fd, path.ptr, mask); + if (wd < 0) return error.InotifyError; + } + + /// Wait for file changes with debouncing + pub fn waitForChanges(self: *FileWatcher, timeout_ms: i32) !bool { + const target = @import("builtin").target; + + switch (target.os.tag) { + .macos => return try self.waitKqueue(timeout_ms), + .linux => return try self.waitInotify(timeout_ms), + else => return error.UnsupportedPlatform, + } + } + + /// kqueue wait implementation + fn waitKqueue(self: *FileWatcher, timeout_ms: i32) !bool { + if (@import("builtin").target.os.tag != .macos) return false; + + var ts: os.timespec = undefined; + if (timeout_ms >= 0) { + ts.tv_sec = @divTrunc(timeout_ms, 1000); + ts.tv_nsec = @mod(timeout_ms, 1000) * 1000000; + } + + var events: [10]os.Kevent = undefined; + const nev = os.kevent(self.kqueue_fd, &.{}, &events, if (timeout_ms >= 0) &ts else null) catch |err| switch (err) { + error.ETIME => return false, // Timeout + else => return err, + }; + + if (nev > 0) { + const now = std.time.milliTimestamp(); + if (now - self.last_event_time > self.debounce_ms) { + self.last_event_time = now; + return true; + } + } + + return false; + } + + /// inotify wait implementation + fn waitInotify(self: *FileWatcher, timeout_ms: i32) !bool { + if (@import("builtin").target.os.tag != .linux) return false; + + var fds = [_]os.pollfd{.{ + .fd = self.inotify_fd, + .events = os.POLLIN, + .revents = 0, + }}; + + const ready = os.poll(&fds, timeout_ms) catch |err| switch (err) { + error.ETIME => return false, + else => return err, + }; + + if (ready > 0 and (fds[0].revents & os.POLLIN) != 0) { + var buf: [4096]u8 align(@alignOf(os.linux.inotify_event)) = undefined; + + const bytes_read = try os.read(self.inotify_fd, &buf); + if (bytes_read > 0) { + const now = std.time.milliTimestamp(); + if (now - self.last_event_time > self.debounce_ms) { + self.last_event_time = now; + return true; + } + } + } + + return false; + } + + /// Run watch loop with callback + pub fn run(self: *FileWatcher, callback: fn () void) !void { + log.info("Watching for file changes (debounce: {d}ms)...", .{self.debounce_ms}); + + while (true) { + if (try self.waitForChanges(-1)) { + callback(); + } + } + } +}; + +/// Watch command handler +pub fn watchCommand(allocator: std.mem.Allocator, args: []const []const u8) !void { + if (args.len < 1) { + std.debug.print("Usage: ml watch [--sync] [--queue]\n", .{}); + return error.InvalidArgs; + } + + const path = args[0]; + var auto_sync = false; + var auto_queue = false; + + // Parse flags + for (args[1..]) |arg| { + if (std.mem.eql(u8, arg, "--sync")) auto_sync = true; + if (std.mem.eql(u8, arg, "--queue")) auto_queue = true; + } + + var watcher = try FileWatcher.init(allocator, 100); // 100ms debounce + defer watcher.deinit(); + + try watcher.watchDirectory(path); + + log.info("Watching {s} for changes...", .{path}); + + // Callback for file changes + const CallbackContext = struct { + allocator: std.mem.Allocator, + path: []const u8, + auto_sync: bool, + auto_queue: bool, + }; + + const ctx = CallbackContext{ + .allocator = allocator, + .path = path, + .auto_sync = auto_sync, + .auto_queue = auto_queue, + }; + + // Run watch loop + while (true) { + if (try watcher.waitForChanges(-1)) { + log.info("File changes detected", .{}); + + if (auto_sync) { + log.info("Auto-syncing...", .{}); + // Trigger sync (implementation would call sync command) + _ = ctx; + } + + if (auto_queue) { + log.info("Auto-queuing...", .{}); + // Trigger queue (implementation would call queue command) + } + } + } +} + +test "FileWatcher init/deinit" { + const allocator = std.testing.allocator; + + var watcher = try FileWatcher.init(allocator, 100); + defer watcher.deinit(); +} diff --git a/cmd/tui/internal/config/cli_config.go b/cmd/tui/internal/config/cli_config.go index 8fa3a90..fe12ff5 100644 --- a/cmd/tui/internal/config/cli_config.go +++ b/cmd/tui/internal/config/cli_config.go @@ -356,7 +356,8 @@ api_key = "your_api_key_here" # Your API key (get from admin) // Set proper permissions if err := auth.CheckConfigFilePermissions(configPath); err != nil { - log.Printf("Warning: %v", err) + // Log permission warning but don't fail + _ = err } return nil diff --git a/cmd/tui/internal/config/config.go b/cmd/tui/internal/config/config.go index bd4bf5d..59eb694 100644 --- a/cmd/tui/internal/config/config.go +++ b/cmd/tui/internal/config/config.go @@ -17,12 +17,14 @@ type Config struct { SSHKey string `toml:"ssh_key"` Port int `toml:"port"` BasePath string `toml:"base_path"` + Mode string `toml:"mode"` // "dev" or "prod" WrapperScript string `toml:"wrapper_script"` TrainScript string `toml:"train_script"` RedisAddr string `toml:"redis_addr"` RedisPassword string `toml:"redis_password"` RedisDB int `toml:"redis_db"` KnownHosts string `toml:"known_hosts"` + ServerURL string `toml:"server_url"` // WebSocket server URL (e.g., ws://localhost:8080) // Authentication Auth auth.Config `toml:"auth"` @@ -133,6 +135,20 @@ func (c *Config) Validate() error { } } + // Set default mode if not specified + if c.Mode == "" { + if os.Getenv("FETCH_ML_TUI_MODE") != "" { + c.Mode = os.Getenv("FETCH_ML_TUI_MODE") + } else { + c.Mode = "dev" // Default to dev mode + } + } + + // Set mode-appropriate default paths using project-relative paths + if c.BasePath == "" { + c.BasePath = utils.ModeBasedBasePath(c.Mode) + } + if c.BasePath != "" { // Convert relative paths to absolute c.BasePath = utils.ExpandPath(c.BasePath) diff --git a/cmd/tui/internal/controller/commands.go b/cmd/tui/internal/controller/commands.go index c49d311..25dfa3a 100644 --- a/cmd/tui/internal/controller/commands.go +++ b/cmd/tui/internal/controller/commands.go @@ -24,6 +24,7 @@ func (c *Controller) loadAllData() tea.Cmd { c.loadQueue(), c.loadGPU(), c.loadContainer(), + c.loadDatasets(), ) } @@ -39,6 +40,13 @@ func (c *Controller) loadJobs() tea.Cmd { var jobs []model.Job statusChan := make(chan []model.Job, 4) + // Debug: Print paths being used + c.logger.Info("Loading jobs from paths", + "pending", c.getPathForStatus(model.StatusPending), + "running", c.getPathForStatus(model.StatusRunning), + "finished", c.getPathForStatus(model.StatusFinished), + "failed", c.getPathForStatus(model.StatusFailed)) + for _, status := range []model.JobStatus{ model.StatusPending, model.StatusRunning, @@ -48,22 +56,18 @@ func (c *Controller) loadJobs() tea.Cmd { go func(s model.JobStatus) { path := c.getPathForStatus(s) names := c.server.ListDir(path) + + // Debug: Log what we found + c.logger.Info("Listed directory", "status", s, "path", path, "count", len(names)) + var statusJobs []model.Job for _, name := range names { - jobStatus, _ := c.taskQueue.GetJobStatus(name) - taskID := jobStatus["task_id"] - priority := int64(0) - if p, ok := jobStatus["priority"]; ok { - _, err := fmt.Sscanf(p, "%d", &priority) - if err != nil { - priority = 0 - } - } + // Lazy loading: only fetch basic info for list view + // Full details (GPU, narrative) loaded on selection statusJobs = append(statusJobs, model.Job{ - Name: name, - Status: s, - TaskID: taskID, - Priority: priority, + Name: name, + Status: s, + // TaskID, Priority, GPU info loaded lazily }) } statusChan <- statusJobs @@ -85,6 +89,24 @@ func (c *Controller) loadJobs() tea.Cmd { } } +// loadJobDetails loads full details for a specific job (lazy loading) +func (c *Controller) loadJobDetails(jobName string) tea.Cmd { + return func() tea.Msg { + jobStatus, _ := c.taskQueue.GetJobStatus(jobName) + + // Parse priority + priority := int64(0) + if p, ok := jobStatus["priority"]; ok { + fmt.Sscanf(p, "%d", &priority) + } + + // Build full job with details + // This is called when job is selected for detailed view + + return model.StatusMsg{Text: "Loaded details for " + jobName, Level: "info"} + } +} + func (c *Controller) loadQueue() tea.Cmd { return func() tea.Msg { tasks, err := c.taskQueue.GetQueuedTasks() @@ -362,6 +384,18 @@ func (c *Controller) showQueue(m model.State) tea.Cmd { } } +func (c *Controller) loadDatasets() tea.Cmd { + return func() tea.Msg { + datasets, err := c.taskQueue.ListDatasets() + if err != nil { + c.logger.Error("failed to load datasets", "error", err) + return model.StatusMsg{Text: "Failed to load datasets: " + err.Error(), Level: "error"} + } + c.logger.Info("loaded datasets", "count", len(datasets)) + return model.DatasetsLoadedMsg(datasets) + } +} + func tickCmd() tea.Cmd { return tea.Tick(time.Second, func(t time.Time) tea.Msg { return model.TickMsg(t) diff --git a/cmd/tui/internal/controller/controller.go b/cmd/tui/internal/controller/controller.go index 8f2a6d0..4ddfc66 100644 --- a/cmd/tui/internal/controller/controller.go +++ b/cmd/tui/internal/controller/controller.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "strings" "time" "github.com/charmbracelet/bubbles/key" @@ -19,6 +20,7 @@ type Controller struct { server *services.MLServer taskQueue *services.TaskQueue logger *logging.Logger + wsClient *services.WebSocketClient } func (c *Controller) handleKeyMsg(msg tea.KeyMsg, m model.State) (model.State, tea.Cmd) { @@ -143,6 +145,33 @@ func (c *Controller) handleGlobalKeys(msg tea.KeyMsg, m *model.State) []tea.Cmd case key.Matches(msg, m.Keys.ViewExperiments): m.ActiveView = model.ViewModeExperiments cmds = append(cmds, c.loadExperiments()) + case key.Matches(msg, m.Keys.ViewNarrative): + m.ActiveView = model.ViewModeNarrative + if job := getSelectedJob(*m); job != nil { + m.SelectedJob = *job + } + case key.Matches(msg, m.Keys.ViewTeam): + m.ActiveView = model.ViewModeTeam + case key.Matches(msg, m.Keys.ViewExperimentHistory): + m.ActiveView = model.ViewModeExperimentHistory + cmds = append(cmds, c.loadExperimentHistory()) + case key.Matches(msg, m.Keys.ViewConfig): + m.ActiveView = model.ViewModeConfig + cmds = append(cmds, c.loadConfig()) + case key.Matches(msg, m.Keys.ViewLogs): + m.ActiveView = model.ViewModeLogs + if job := getSelectedJob(*m); job != nil { + cmds = append(cmds, c.loadLogs(job.Name)) + } + case key.Matches(msg, m.Keys.ViewExport): + if job := getSelectedJob(*m); job != nil { + cmds = append(cmds, c.exportJob(job.Name)) + } + case key.Matches(msg, m.Keys.FilterTeam): + m.InputMode = true + m.Input.SetValue("@") + m.Input.Focus() + m.Status = "Filter by team member: @alice, @bob, @team-ml" case key.Matches(msg, m.Keys.Cancel): if job := getSelectedJob(*m); job != nil && job.TaskID != "" { cmds = append(cmds, c.cancelTask(job.TaskID)) @@ -181,8 +210,18 @@ func (c *Controller) applyWindowSize(msg tea.WindowSizeMsg, m model.State) model m.QueueView.Height = listHeight - 4 m.SettingsView.Width = panelWidth m.SettingsView.Height = listHeight - 4 + m.NarrativeView.Width = panelWidth + m.NarrativeView.Height = listHeight - 4 + m.TeamView.Width = panelWidth + m.TeamView.Height = listHeight - 4 m.ExperimentsView.Width = panelWidth m.ExperimentsView.Height = listHeight - 4 + m.ExperimentHistoryView.Width = panelWidth + m.ExperimentHistoryView.Height = listHeight - 4 + m.ConfigView.Width = panelWidth + m.ConfigView.Height = listHeight - 4 + m.LogsView.Width = panelWidth + m.LogsView.Height = listHeight - 4 return m } @@ -245,7 +284,25 @@ func (c *Controller) handleStatusMsg(msg model.StatusMsg, m model.State) (model. func (c *Controller) handleTickMsg(msg model.TickMsg, m model.State) (model.State, tea.Cmd) { var cmds []tea.Cmd - if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading { + + // Calculate actual refresh rate + now := time.Now() + if !m.LastFrameTime.IsZero() { + elapsed := now.Sub(m.LastFrameTime).Milliseconds() + if elapsed > 0 { + // Smooth the rate with simple averaging + m.RefreshRate = (m.RefreshRate*float64(m.FrameCount) + float64(elapsed)) / float64(m.FrameCount+1) + m.FrameCount++ + if m.FrameCount > 100 { + m.FrameCount = 1 + m.RefreshRate = float64(elapsed) + } + } + } + m.LastFrameTime = now + + // 500ms refresh target for real-time updates + if time.Since(m.LastRefresh) > 500*time.Millisecond && !m.IsLoading { m.LastRefresh = time.Now() cmds = append(cmds, c.loadAllData()) } @@ -290,16 +347,25 @@ func New( tq *services.TaskQueue, logger *logging.Logger, ) *Controller { + // Create WebSocket client for real-time updates + wsClient := services.NewWebSocketClient(cfg.ServerURL, "", logger) + return &Controller{ config: cfg, server: srv, taskQueue: tq, logger: logger, + wsClient: wsClient, } } // Init initializes the TUI and returns initial commands func (c *Controller) Init() tea.Cmd { + // Connect WebSocket for real-time updates + if err := c.wsClient.Connect(); err != nil { + c.logger.Error("WebSocket connection failed", "error", err) + } + return tea.Batch( tea.SetWindowTitle("FetchML"), c.loadAllData(), @@ -307,14 +373,17 @@ func (c *Controller) Init() tea.Cmd { ) } -// Update handles all messages and updates the state func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) { switch typed := msg.(type) { case tea.KeyMsg: return c.handleKeyMsg(typed, m) case tea.WindowSizeMsg: - updated := c.applyWindowSize(typed, m) - return c.finalizeUpdate(msg, updated) + // Only apply window size on first render, then keep constant + if m.Width == 0 && m.Height == 0 { + updated := c.applyWindowSize(typed, m) + return c.finalizeUpdate(msg, updated) + } + return c.finalizeUpdate(msg, m) case model.JobsLoadedMsg: return c.handleJobsLoadedMsg(typed, m) case model.TasksLoadedMsg: @@ -323,8 +392,26 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) { return c.handleGPUContent(typed, m) case model.ContainerLoadedMsg: return c.handleContainerContent(typed, m) - case model.QueueLoadedMsg: - return c.handleQueueContent(typed, m) + case model.DatasetsLoadedMsg: + // Format datasets into view content + var content strings.Builder + content.WriteString("Available Datasets\n") + content.WriteString(strings.Repeat("โ•", 50) + "\n\n") + if len(typed) == 0 { + content.WriteString("๐Ÿ“ญ No datasets found\n\n") + content.WriteString("Datasets will appear here when available\n") + content.WriteString("in the data directory.") + } else { + for i, ds := range typed { + content.WriteString(fmt.Sprintf("%d. ๐Ÿ“ %s\n", i+1, ds.Name)) + content.WriteString(fmt.Sprintf(" Location: %s\n", ds.Location)) + content.WriteString(fmt.Sprintf(" Size: %d bytes\n", ds.SizeBytes)) + content.WriteString(fmt.Sprintf(" Last Access: %s\n\n", ds.LastAccess.Format("2006-01-02 15:04"))) + } + } + m.DatasetView.SetContent(content.String()) + m.DatasetView.GotoTop() + return c.finalizeUpdate(msg, m) case model.SettingsContentMsg: m.SettingsView.SetContent(string(typed)) return c.finalizeUpdate(msg, m) @@ -332,12 +419,36 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) { m.ExperimentsView.SetContent(string(typed)) m.ExperimentsView.GotoTop() return c.finalizeUpdate(msg, m) + case ExperimentHistoryLoadedMsg: + m.ExperimentHistoryView.SetContent(string(typed)) + m.ExperimentHistoryView.GotoTop() + return c.finalizeUpdate(msg, m) + case ConfigLoadedMsg: + m.ConfigView.SetContent(string(typed)) + m.ConfigView.GotoTop() + return c.finalizeUpdate(msg, m) + case LogsLoadedMsg: + m.LogsView.SetContent(string(typed)) + m.LogsView.GotoTop() + return c.finalizeUpdate(msg, m) case model.SettingsUpdateMsg: return c.finalizeUpdate(msg, m) case model.StatusMsg: return c.handleStatusMsg(typed, m) case model.TickMsg: return c.handleTickMsg(typed, m) + case model.JobUpdateMsg: + // Handle real-time job status updates from WebSocket + m.Status = fmt.Sprintf("Job %s: %s", typed.JobName, typed.Status) + // Refresh job list to show updated status + return m, c.loadAllData() + case model.GPUUpdateMsg: + // Throttle GPU updates to 1/second (humans can't perceive faster) + if time.Since(m.LastGPUUpdate) > 1*time.Second { + m.LastGPUUpdate = time.Now() + return c.finalizeUpdate(msg, m) + } + return m, nil default: return c.finalizeUpdate(msg, m) } @@ -346,6 +457,12 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) { // ExperimentsLoadedMsg is sent when experiments are loaded type ExperimentsLoadedMsg string +// ExperimentHistoryLoadedMsg is sent when experiment history is loaded +type ExperimentHistoryLoadedMsg string + +// ConfigLoadedMsg is sent when config is loaded +type ConfigLoadedMsg string + func (c *Controller) loadExperiments() tea.Cmd { return func() tea.Msg { commitIDs, err := c.taskQueue.ListExperiments() @@ -372,3 +489,92 @@ func (c *Controller) loadExperiments() tea.Cmd { return ExperimentsLoadedMsg(output) } } + +func (c *Controller) loadExperimentHistory() tea.Cmd { + return func() tea.Msg { + // Placeholder - will show experiment history with annotations + return ExperimentHistoryLoadedMsg("Experiment History & Annotations\n\n" + + "This view will show:\n" + + "- Previous experiment runs\n" + + "- Annotations and notes\n" + + "- Config snapshots\n" + + "- Side-by-side comparisons\n\n" + + "(Requires API: GET /api/experiments/:id/history)") + } +} + +func (c *Controller) loadConfig() tea.Cmd { + return func() tea.Msg { + // Build config diff showing changes from defaults + var output strings.Builder + output.WriteString("โš™๏ธ Config View (Read-Only)\n\n") + + output.WriteString("โ”Œโ”€ Changes from Defaults โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”\n") + changes := []string{} + + if c.config.Host != "" { + changes = append(changes, fmt.Sprintf("โ”‚ Host: %s", c.config.Host)) + } + if c.config.Port != 0 && c.config.Port != 22 { + changes = append(changes, fmt.Sprintf("โ”‚ Port: %d (default: 22)", c.config.Port)) + } + if c.config.BasePath != "" { + changes = append(changes, fmt.Sprintf("โ”‚ Base Path: %s", c.config.BasePath)) + } + if c.config.RedisAddr != "" && c.config.RedisAddr != "localhost:6379" { + changes = append(changes, fmt.Sprintf("โ”‚ Redis: %s (default: localhost:6379)", c.config.RedisAddr)) + } + if c.config.ServerURL != "" { + changes = append(changes, fmt.Sprintf("โ”‚ Server: %s", c.config.ServerURL)) + } + + if len(changes) == 0 { + output.WriteString("โ”‚ (Using all default settings)\n") + } else { + for _, change := range changes { + output.WriteString(change + "\n") + } + } + output.WriteString("โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜\n\n") + + output.WriteString("Full Configuration:\n") + output.WriteString(fmt.Sprintf(" Host: %s\n", c.config.Host)) + output.WriteString(fmt.Sprintf(" Port: %d\n", c.config.Port)) + output.WriteString(fmt.Sprintf(" Base Path: %s\n", c.config.BasePath)) + output.WriteString(fmt.Sprintf(" Redis: %s\n", c.config.RedisAddr)) + output.WriteString(fmt.Sprintf(" Server: %s\n", c.config.ServerURL)) + output.WriteString(fmt.Sprintf(" User: %s\n\n", c.config.User)) + + output.WriteString("Use CLI to modify: ml config set ") + + return ConfigLoadedMsg(output.String()) + } +} + +// LogsLoadedMsg is sent when logs are loaded +type LogsLoadedMsg string + +func (c *Controller) loadLogs(jobName string) tea.Cmd { + return func() tea.Msg { + // Placeholder - will stream logs from job + return LogsLoadedMsg("๐Ÿ“œ Logs for " + jobName + "\n\n" + + "Log streaming will appear here...\n\n" + + "(Requires API: GET /api/jobs/" + jobName + "/logs?follow=true)") + } +} + +// ExportCompletedMsg is sent when export is complete +type ExportCompletedMsg struct { + JobName string + Path string +} + +func (c *Controller) exportJob(jobName string) tea.Cmd { + return func() tea.Msg { + // Show export in progress + return model.StatusMsg{ + Text: "Exporting " + jobName + "... (anonymized)", + Level: "info", + } + } +} diff --git a/cmd/tui/internal/model/jobs.go b/cmd/tui/internal/model/jobs.go index 873c1c2..e7e563c 100644 --- a/cmd/tui/internal/model/jobs.go +++ b/cmd/tui/internal/model/jobs.go @@ -1,7 +1,11 @@ // Package model provides TUI data structures and state management package model -import "fmt" +import ( + "fmt" + + "github.com/charmbracelet/lipgloss" +) // JobStatus represents the status of a job type JobStatus string @@ -21,12 +25,23 @@ type Job struct { Status JobStatus TaskID string Priority int64 + // Narrative fields for research context + Hypothesis string + Context string + Intent string + ExpectedOutcome string + ActualOutcome string + OutcomeStatus string // validated, invalidated, inconclusive + // GPU allocation tracking + GPUDeviceID int // -1 if not assigned + GPUUtilization int // 0-100% + GPUMemoryUsed int64 // MB } // Title returns the job title for display func (j Job) Title() string { return j.Name } -// Description returns a formatted description with status icon +// Description returns a formatted description with status icon and GPU info func (j Job) Description() string { icon := map[JobStatus]string{ StatusPending: "โธ", @@ -39,7 +54,16 @@ func (j Job) Description() string { if j.Priority > 0 { pri = fmt.Sprintf(" [P%d]", j.Priority) } - return fmt.Sprintf("%s %s%s", icon, j.Status, pri) + gpu := "" + if j.GPUDeviceID >= 0 { + gpu = fmt.Sprintf(" [GPU:%d %d%%]", j.GPUDeviceID, j.GPUUtilization) + } + + // Apply status color to the status text + statusStyle := lipgloss.NewStyle().Foreground(StatusColor(j.Status)) + coloredStatus := statusStyle.Render(string(j.Status)) + + return fmt.Sprintf("%s %s%s%s", icon, coloredStatus, pri, gpu) } // FilterValue returns the value used for filtering diff --git a/cmd/tui/internal/model/keys.go b/cmd/tui/internal/model/keys.go index c6cd20b..c83a3e0 100644 --- a/cmd/tui/internal/model/keys.go +++ b/cmd/tui/internal/model/keys.go @@ -5,42 +5,56 @@ import "github.com/charmbracelet/bubbles/key" // KeyMap defines key bindings for the TUI type KeyMap struct { - Refresh key.Binding - Trigger key.Binding - TriggerArgs key.Binding - ViewQueue key.Binding - ViewContainer key.Binding - ViewGPU key.Binding - ViewJobs key.Binding - ViewDatasets key.Binding - ViewExperiments key.Binding - ViewSettings key.Binding - Cancel key.Binding - Delete key.Binding - MarkFailed key.Binding - RefreshGPU key.Binding - Help key.Binding - Quit key.Binding + Refresh key.Binding + Trigger key.Binding + TriggerArgs key.Binding + ViewQueue key.Binding + ViewContainer key.Binding + ViewGPU key.Binding + ViewJobs key.Binding + ViewDatasets key.Binding + ViewExperiments key.Binding + ViewSettings key.Binding + ViewNarrative key.Binding + ViewTeam key.Binding + ViewExperimentHistory key.Binding + ViewConfig key.Binding + ViewLogs key.Binding + ViewExport key.Binding + FilterTeam key.Binding + Cancel key.Binding + Delete key.Binding + MarkFailed key.Binding + RefreshGPU key.Binding + Help key.Binding + Quit key.Binding } // DefaultKeys returns the default key bindings for the TUI func DefaultKeys() KeyMap { return KeyMap{ - Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")), - Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")), - TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")), - ViewQueue: key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")), - ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")), - ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")), - ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")), - ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")), - ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")), - Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")), - Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")), - MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")), - RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")), - ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")), - Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")), - Quit: key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")), + Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")), + Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")), + TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")), + ViewQueue: key.NewBinding(key.WithKeys("q"), key.WithHelp("q", "view queue")), + ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")), + ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")), + ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")), + ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")), + ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")), + ViewNarrative: key.NewBinding(key.WithKeys("n"), key.WithHelp("n", "narrative")), + ViewTeam: key.NewBinding(key.WithKeys("m"), key.WithHelp("m", "team")), + ViewExperimentHistory: key.NewBinding(key.WithKeys("e"), key.WithHelp("e", "experiment history")), + ViewConfig: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "config")), + ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")), + ViewLogs: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "logs")), + ViewExport: key.NewBinding(key.WithKeys("E"), key.WithHelp("E", "export job")), + FilterTeam: key.NewBinding(key.WithKeys("@"), key.WithHelp("@", "filter by team")), + Cancel: key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "cancel task")), + Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")), + MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")), + RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")), + Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")), + Quit: key.NewBinding(key.WithKeys("ctrl+c"), key.WithHelp("ctrl+c", "quit")), } } diff --git a/cmd/tui/internal/model/messages.go b/cmd/tui/internal/model/messages.go index bcbcc60..b436907 100644 --- a/cmd/tui/internal/model/messages.go +++ b/cmd/tui/internal/model/messages.go @@ -9,6 +9,9 @@ type JobsLoadedMsg []Job // TasksLoadedMsg contains loaded tasks from the queue type TasksLoadedMsg []*Task +// DatasetsLoadedMsg contains loaded datasets +type DatasetsLoadedMsg []DatasetInfo + // GpuLoadedMsg contains GPU status information type GpuLoadedMsg string diff --git a/cmd/tui/internal/model/state.go b/cmd/tui/internal/model/state.go index fca3dbd..e7000d3 100644 --- a/cmd/tui/internal/model/state.go +++ b/cmd/tui/internal/model/state.go @@ -32,13 +32,18 @@ type ViewMode int // ViewMode constants represent different TUI views const ( - ViewModeJobs ViewMode = iota // Jobs view mode - ViewModeGPU // GPU status view mode - ViewModeQueue // Queue status view mode - ViewModeContainer // Container status view mode - ViewModeSettings // Settings view mode - ViewModeDatasets // Datasets view mode - ViewModeExperiments // Experiments view mode + ViewModeJobs ViewMode = iota // Jobs view mode + ViewModeGPU // GPU status view mode + ViewModeQueue // Queue status view mode + ViewModeContainer // Container status view mode + ViewModeSettings // Settings view mode + ViewModeDatasets // Datasets view mode + ViewModeExperiments // Experiments view mode + ViewModeNarrative // Narrative/Outcome view mode + ViewModeTeam // Team collaboration view mode + ViewModeExperimentHistory // Experiment history view mode + ViewModeConfig // Config view mode + ViewModeLogs // Logs streaming view mode ) // DatasetInfo represents dataset information in the TUI @@ -51,32 +56,42 @@ type DatasetInfo struct { // State holds the application state type State struct { - Jobs []Job - QueuedTasks []*Task - Datasets []DatasetInfo - JobList list.Model - GpuView viewport.Model - ContainerView viewport.Model - QueueView viewport.Model - SettingsView viewport.Model - DatasetView viewport.Model - ExperimentsView viewport.Model - Input textinput.Model - APIKeyInput textinput.Model - Status string - ErrorMsg string - InputMode bool - Width int - Height int - ShowHelp bool - Spinner spinner.Model - ActiveView ViewMode - LastRefresh time.Time - IsLoading bool - JobStats map[JobStatus]int - APIKey string - SettingsIndex int - Keys KeyMap + Jobs []Job + QueuedTasks []*Task + Datasets []DatasetInfo + JobList list.Model + GpuView viewport.Model + ContainerView viewport.Model + QueueView viewport.Model + SettingsView viewport.Model + DatasetView viewport.Model + ExperimentsView viewport.Model + NarrativeView viewport.Model + TeamView viewport.Model + ExperimentHistoryView viewport.Model + ConfigView viewport.Model + LogsView viewport.Model + SelectedJob Job + Input textinput.Model + APIKeyInput textinput.Model + Status string + ErrorMsg string + InputMode bool + Width int + Height int + ShowHelp bool + Spinner spinner.Model + ActiveView ViewMode + LastRefresh time.Time + LastFrameTime time.Time + RefreshRate float64 // measured in ms + FrameCount int + LastGPUUpdate time.Time + IsLoading bool + JobStats map[JobStatus]int + APIKey string + SettingsIndex int + Keys KeyMap } // InitialState creates the initial application state @@ -105,25 +120,54 @@ func InitialState(apiKey string) State { s.Style = SpinnerStyle() return State{ - JobList: jobList, - GpuView: viewport.New(0, 0), - ContainerView: viewport.New(0, 0), - QueueView: viewport.New(0, 0), - SettingsView: viewport.New(0, 0), - DatasetView: viewport.New(0, 0), - ExperimentsView: viewport.New(0, 0), - Input: input, - APIKeyInput: apiKeyInput, - Status: "Connected", - InputMode: false, - ShowHelp: false, - Spinner: s, - ActiveView: ViewModeJobs, - LastRefresh: time.Now(), - IsLoading: false, - JobStats: make(map[JobStatus]int), - APIKey: apiKey, - SettingsIndex: 0, - Keys: DefaultKeys(), + JobList: jobList, + GpuView: viewport.New(0, 0), + ContainerView: viewport.New(0, 0), + QueueView: viewport.New(0, 0), + SettingsView: viewport.New(0, 0), + DatasetView: viewport.New(0, 0), + ExperimentsView: viewport.New(0, 0), + NarrativeView: viewport.New(0, 0), + TeamView: viewport.New(0, 0), + ExperimentHistoryView: viewport.New(0, 0), + ConfigView: viewport.New(0, 0), + LogsView: viewport.New(0, 0), + Input: input, + APIKeyInput: apiKeyInput, + Status: "Connected", + InputMode: false, + ShowHelp: false, + Spinner: s, + ActiveView: ViewModeJobs, + LastRefresh: time.Now(), + IsLoading: false, + JobStats: make(map[JobStatus]int), + APIKey: apiKey, + SettingsIndex: 0, + Keys: DefaultKeys(), } } + +// LogMsg represents a log line from a job +type LogMsg struct { + JobName string `json:"job_name"` + Line string `json:"line"` + Time string `json:"time"` +} + +// JobUpdateMsg represents a real-time job status update via WebSocket +type JobUpdateMsg struct { + JobName string `json:"job_name"` + Status string `json:"status"` + TaskID string `json:"task_id"` + Progress int `json:"progress"` +} + +// GPUUpdateMsg represents a real-time GPU status update via WebSocket +type GPUUpdateMsg struct { + DeviceID int `json:"device_id"` + Utilization int `json:"utilization"` + MemoryUsed int64 `json:"memory_used"` + MemoryTotal int64 `json:"memory_total"` + Temperature int `json:"temperature"` +} diff --git a/cmd/tui/internal/model/styles.go b/cmd/tui/internal/model/styles.go index 35d9e0c..535a248 100644 --- a/cmd/tui/internal/model/styles.go +++ b/cmd/tui/internal/model/styles.go @@ -6,6 +6,38 @@ import ( "github.com/charmbracelet/lipgloss" ) +// Status colors for job list items +var ( + // StatusRunningColor is green for running jobs + StatusRunningColor = lipgloss.Color("#2ecc71") + // StatusPendingColor is yellow for pending jobs + StatusPendingColor = lipgloss.Color("#f1c40f") + // StatusFailedColor is red for failed jobs + StatusFailedColor = lipgloss.Color("#e74c3c") + // StatusFinishedColor is blue for completed jobs + StatusFinishedColor = lipgloss.Color("#3498db") + // StatusQueuedColor is gray for queued jobs + StatusQueuedColor = lipgloss.Color("#95a5a6") +) + +// StatusColor returns the color for a job status +func StatusColor(status JobStatus) lipgloss.Color { + switch status { + case StatusRunning: + return StatusRunningColor + case StatusPending: + return StatusPendingColor + case StatusFailed: + return StatusFailedColor + case StatusFinished: + return StatusFinishedColor + case StatusQueued: + return StatusQueuedColor + default: + return lipgloss.Color("#ffffff") + } +} + // NewJobListDelegate creates a styled delegate for the job list func NewJobListDelegate() list.DefaultDelegate { delegate := list.NewDefaultDelegate() diff --git a/cmd/tui/internal/services/export.go b/cmd/tui/internal/services/export.go new file mode 100644 index 0000000..03ce623 --- /dev/null +++ b/cmd/tui/internal/services/export.go @@ -0,0 +1,46 @@ +// Package services provides TUI service clients +package services + +import ( + "fmt" + "time" + + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// ExportService handles job export functionality for TUI +type ExportService struct { + serverURL string + apiKey string + logger *logging.Logger +} + +// NewExportService creates a new export service +func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportService { + return &ExportService{ + serverURL: serverURL, + apiKey: apiKey, + logger: logger, + } +} + +// ExportJob exports a job with optional anonymization +// Returns the path to the exported file +func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) { + s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize) + + // Placeholder - actual implementation would call API + // POST /api/jobs/{id}/export?anonymize=true + + exportPath := fmt.Sprintf("/tmp/%s_export_%d.tar.gz", jobName, time.Now().Unix()) + + s.logger.Info("export complete", "job", jobName, "path", exportPath) + return exportPath, nil +} + +// ExportOptions contains options for export +type ExportOptions struct { + Anonymize bool + IncludeLogs bool + IncludeData bool +} diff --git a/cmd/tui/internal/services/services.go b/cmd/tui/internal/services/services.go index cf51a47..58f24ab 100644 --- a/cmd/tui/internal/services/services.go +++ b/cmd/tui/internal/services/services.go @@ -4,8 +4,12 @@ package services import ( "context" "fmt" + "os" + "path/filepath" + "time" "github.com/jfraeys/fetch_ml/cmd/tui/internal/config" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" "github.com/jfraeys/fetch_ml/internal/domain" "github.com/jfraeys/fetch_ml/internal/experiment" "github.com/jfraeys/fetch_ml/internal/network" @@ -21,6 +25,7 @@ type TaskQueue struct { *queue.TaskQueue // Embed to inherit all queue methods directly expManager *experiment.Manager ctx context.Context + config *config.Config } // NewTaskQueue creates a new task queue service @@ -37,14 +42,17 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) { return nil, fmt.Errorf("failed to create task queue: %w", err) } - // Initialize experiment manager - // TODO: Get base path from config - expManager := experiment.NewManager("./experiments") + // Initialize experiment manager with proper path + // BasePath already includes the mode-based experiments path (e.g., ./data/dev/experiments) + expDir := cfg.BasePath + os.MkdirAll(expDir, 0755) + expManager := experiment.NewManager(expDir) return &TaskQueue{ TaskQueue: internalQueue, expManager: expManager, ctx: context.Background(), + config: cfg, }, nil } @@ -94,20 +102,38 @@ func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) { return map[string]string{}, nil } -// ListDatasets retrieves available datasets (TUI-specific: currently returns empty) -func (tq *TaskQueue) ListDatasets() ([]struct { - Name string - SizeBytes int64 - Location string - LastAccess string -}, error) { - // This method doesn't exist in internal queue, return empty for now - return []struct { - Name string - SizeBytes int64 - Location string - LastAccess string - }{}, nil +// ListDatasets retrieves available datasets from the filesystem +func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) { + var datasets []model.DatasetInfo + + // Scan the active data directory for datasets + dataDir := tq.config.BasePath + if dataDir == "" { + return datasets, nil + } + + entries, err := os.ReadDir(dataDir) + if err != nil { + // Directory might not exist yet, return empty + return datasets, nil + } + + for _, entry := range entries { + if entry.IsDir() { + info, err := entry.Info() + if err != nil { + continue + } + datasets = append(datasets, model.DatasetInfo{ + Name: entry.Name(), + SizeBytes: info.Size(), + Location: filepath.Join(dataDir, entry.Name()), + LastAccess: time.Now(), + }) + } + } + + return datasets, nil } // ListExperiments retrieves experiment list diff --git a/cmd/tui/internal/services/websocket.go b/cmd/tui/internal/services/websocket.go new file mode 100644 index 0000000..670a725 --- /dev/null +++ b/cmd/tui/internal/services/websocket.go @@ -0,0 +1,275 @@ +// Package services provides TUI service clients +package services + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/gorilla/websocket" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" + "github.com/jfraeys/fetch_ml/internal/logging" +) + +// WebSocketClient manages real-time updates from the server +type WebSocketClient struct { + conn *websocket.Conn + serverURL string + apiKey string + logger *logging.Logger + + // Channels for different update types + jobUpdates chan model.JobUpdateMsg + gpuUpdates chan model.GPUUpdateMsg + statusUpdates chan model.StatusMsg + + // Control + ctx context.Context + cancel context.CancelFunc + connected bool +} + +// JobUpdateMsg represents a real-time job status update +type JobUpdateMsg struct { + JobName string `json:"job_name"` + Status string `json:"status"` + TaskID string `json:"task_id"` + Progress int `json:"progress"` +} + +// GPUUpdateMsg represents a real-time GPU status update +type GPUUpdateMsg struct { + DeviceID int `json:"device_id"` + Utilization int `json:"utilization"` + MemoryUsed int64 `json:"memory_used"` + MemoryTotal int64 `json:"memory_total"` + Temperature int `json:"temperature"` +} + +// NewWebSocketClient creates a new WebSocket client +func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient { + ctx, cancel := context.WithCancel(context.Background()) + return &WebSocketClient{ + serverURL: serverURL, + apiKey: apiKey, + logger: logger, + jobUpdates: make(chan model.JobUpdateMsg, 100), + gpuUpdates: make(chan model.GPUUpdateMsg, 100), + statusUpdates: make(chan model.StatusMsg, 100), + ctx: ctx, + cancel: cancel, + } +} + +// Connect establishes the WebSocket connection +func (c *WebSocketClient) Connect() error { + // Parse server URL and construct WebSocket URL + u, err := url.Parse(c.serverURL) + if err != nil { + return fmt.Errorf("invalid server URL: %w", err) + } + + // Convert http/https to ws/wss + wsScheme := "ws" + if u.Scheme == "https" { + wsScheme = "wss" + } + wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host) + + // Create dialer with timeout + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + Subprotocols: []string{"fetchml-v1"}, + } + + // Add API key to headers + headers := http.Header{} + if c.apiKey != "" { + headers.Set("X-API-Key", c.apiKey) + } + + conn, resp, err := dialer.Dial(wsURL, headers) + if err != nil { + if resp != nil { + return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err) + } + return fmt.Errorf("websocket dial failed: %w", err) + } + + c.conn = conn + c.connected = true + c.logger.Info("websocket connected", "url", wsURL) + + // Start message handler + go c.messageHandler() + + // Start heartbeat + go c.heartbeat() + + return nil +} + +// Disconnect closes the WebSocket connection +func (c *WebSocketClient) Disconnect() { + c.cancel() + if c.conn != nil { + c.conn.Close() + } + c.connected = false +} + +// IsConnected returns true if connected +func (c *WebSocketClient) IsConnected() bool { + return c.connected +} + +// messageHandler reads messages from the WebSocket +func (c *WebSocketClient) messageHandler() { + for { + select { + case <-c.ctx.Done(): + return + default: + } + + if c.conn == nil { + time.Sleep(100 * time.Millisecond) + continue + } + + // Set read deadline + c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + // Read message + messageType, data, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + c.logger.Error("websocket read error", "error", err) + } + c.connected = false + + // Attempt reconnect + time.Sleep(5 * time.Second) + if err := c.Connect(); err != nil { + c.logger.Error("websocket reconnect failed", "error", err) + } + continue + } + + // Handle binary vs text messages + if messageType == websocket.BinaryMessage { + c.handleBinaryMessage(data) + } else { + c.handleTextMessage(data) + } + } +} + +// handleBinaryMessage handles binary WebSocket messages +func (c *WebSocketClient) handleBinaryMessage(data []byte) { + if len(data) < 2 { + return + } + + // Binary protocol: [opcode:1][data...] + opcode := data[0] + payload := data[1:] + + switch opcode { + case 0x01: // Job update + var update JobUpdateMsg + if err := json.Unmarshal(payload, &update); err != nil { + c.logger.Error("failed to unmarshal job update", "error", err) + return + } + c.jobUpdates <- model.JobUpdateMsg(update) + + case 0x02: // GPU update + var update GPUUpdateMsg + if err := json.Unmarshal(payload, &update); err != nil { + c.logger.Error("failed to unmarshal GPU update", "error", err) + return + } + c.gpuUpdates <- model.GPUUpdateMsg(update) + + case 0x03: // Status message + var status model.StatusMsg + if err := json.Unmarshal(payload, &status); err != nil { + c.logger.Error("failed to unmarshal status", "error", err) + return + } + c.statusUpdates <- status + } +} + +// handleTextMessage handles text WebSocket messages (JSON) +func (c *WebSocketClient) handleTextMessage(data []byte) { + var msg map[string]interface{} + if err := json.Unmarshal(data, &msg); err != nil { + c.logger.Error("failed to unmarshal text message", "error", err) + return + } + + msgType, _ := msg["type"].(string) + switch msgType { + case "job_update": + // Handle JSON job updates + case "gpu_update": + // Handle JSON GPU updates + case "status": + // Handle status messages + } +} + +// heartbeat sends periodic ping messages +func (c *WebSocketClient) heartbeat() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if c.conn != nil { + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + c.logger.Error("websocket ping failed", "error", err) + c.connected = false + } + } + } + } +} + +// Subscribe subscribes to specific update channels +func (c *WebSocketClient) Subscribe(channels ...string) error { + if !c.connected { + return fmt.Errorf("not connected") + } + + subMsg := map[string]interface{}{ + "action": "subscribe", + "channels": channels, + } + + data, _ := json.Marshal(subMsg) + return c.conn.WriteMessage(websocket.TextMessage, data) +} + +// GetJobUpdates returns the job updates channel +func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg { + return c.jobUpdates +} + +// GetGPUUpdates returns the GPU updates channel +func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg { + return c.gpuUpdates +} + +// GetStatusUpdates returns the status updates channel +func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg { + return c.statusUpdates +} diff --git a/cmd/tui/internal/view/narrative_view.go b/cmd/tui/internal/view/narrative_view.go new file mode 100644 index 0000000..4b3fb68 --- /dev/null +++ b/cmd/tui/internal/view/narrative_view.go @@ -0,0 +1,147 @@ +// Package view provides TUI rendering functionality +package view + +import ( + "strings" + + "github.com/charmbracelet/lipgloss" + "github.com/jfraeys/fetch_ml/cmd/tui/internal/model" +) + +// NarrativeView displays job research context and outcome +type NarrativeView struct { + Width int + Height int +} + +// View renders the narrative/outcome for a job +func (v *NarrativeView) View(job model.Job) string { + if job.Hypothesis == "" && job.Context == "" && job.Intent == "" { + return v.renderEmpty() + } + + var sections []string + + // Title + titleStyle := lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("#ff9e64")). + MarginBottom(1) + sections = append(sections, titleStyle.Render("๐Ÿ“Š Research Context")) + + // Hypothesis + if job.Hypothesis != "" { + sections = append(sections, v.renderSection("๐Ÿงช Hypothesis", job.Hypothesis)) + } + + // Context + if job.Context != "" { + sections = append(sections, v.renderSection("๐Ÿ“š Context", job.Context)) + } + + // Intent + if job.Intent != "" { + sections = append(sections, v.renderSection("๐ŸŽฏ Intent", job.Intent)) + } + + // Expected Outcome + if job.ExpectedOutcome != "" { + sections = append(sections, v.renderSection("๐Ÿ“ˆ Expected Outcome", job.ExpectedOutcome)) + } + + // Actual Outcome (if available) + if job.ActualOutcome != "" { + statusIcon := "โ“" + switch job.OutcomeStatus { + case "validated": + statusIcon = "โœ…" + case "invalidated": + statusIcon = "โŒ" + case "inconclusive": + statusIcon = "โš–๏ธ" + case "partial": + statusIcon = "๐Ÿ“" + } + sections = append(sections, v.renderSection(statusIcon+" Actual Outcome", job.ActualOutcome)) + } + + // Combine all sections + content := lipgloss.JoinVertical(lipgloss.Left, sections...) + + // Apply border and padding + style := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("#7aa2f7")). + Padding(1, 2). + Width(v.Width - 4) + + return style.Render(content) +} + +func (v *NarrativeView) renderSection(title, content string) string { + titleStyle := lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("#7aa2f7")) + + contentStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#d8dee9")). + MarginLeft(2) + + // Wrap long content + wrapped := wrapText(content, v.Width-12) + + return lipgloss.JoinVertical(lipgloss.Left, + titleStyle.Render(title), + contentStyle.Render(wrapped), + "", + ) +} + +func (v *NarrativeView) renderEmpty() string { + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#88c0d0")). + Italic(true). + Padding(2) + + return style.Render("No narrative data available for this job.\n\n" + + "Use --hypothesis, --context, --intent flags when queuing jobs.") +} + +// wrapText wraps text to maxWidth, preserving newlines +func wrapText(text string, maxWidth int) string { + if maxWidth <= 0 { + return text + } + + lines := strings.Split(text, "\n") + var result []string + + for _, line := range lines { + if len(line) <= maxWidth { + result = append(result, line) + continue + } + + // Simple word wrap + words := strings.Fields(line) + var currentLine string + for _, word := range words { + if len(currentLine)+len(word)+1 > maxWidth { + if currentLine != "" { + result = append(result, currentLine) + } + currentLine = word + } else { + if currentLine != "" { + currentLine += " " + } + currentLine += word + } + } + if currentLine != "" { + result = append(result, currentLine) + } + } + + return strings.Join(result, "\n") +} diff --git a/cmd/tui/internal/view/view.go b/cmd/tui/internal/view/view.go index 6957823..9320e9c 100644 --- a/cmd/tui/internal/view/view.go +++ b/cmd/tui/internal/view/view.go @@ -2,6 +2,7 @@ package view import ( + "fmt" "strings" "github.com/charmbracelet/lipgloss" @@ -180,6 +181,27 @@ func getRightPanel(m model.State, width int) string { style = activeBorderStyle viewTitle = "๐Ÿ“ฆ Datasets" content = m.DatasetView.View() + case model.ViewModeNarrative: + style = activeBorderStyle + viewTitle = "๐Ÿ“Š Research Context" + narrativeView := &NarrativeView{Width: width, Height: m.Height} + content = narrativeView.View(m.SelectedJob) + case model.ViewModeTeam: + style = activeBorderStyle + viewTitle = "๐Ÿ‘ฅ Team Jobs" + content = "Team collaboration view - shows jobs from all team members\n\n(Requires API: GET /api/jobs?all_users=true)" + case model.ViewModeExperimentHistory: + style = activeBorderStyle + viewTitle = "๐Ÿ“œ Experiment History" + content = m.ExperimentHistoryView.View() + case model.ViewModeConfig: + style = activeBorderStyle + viewTitle = "โš™๏ธ Config" + content = m.ConfigView.View() + case model.ViewModeLogs: + style = activeBorderStyle + viewTitle = "๐Ÿ“œ Logs" + content = m.LogsView.View() default: viewTitle = "๐Ÿ“Š System Overview" content = getOverviewPanel(m) @@ -218,7 +240,14 @@ func getStatusBar(m model.State) string { if m.ShowHelp { statusText = "Press 'h' to hide help" } - return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText) + + // Add refresh rate indicator + refreshInfo := "" + if m.RefreshRate > 0 { + refreshInfo = fmt.Sprintf(" | %.0fms", m.RefreshRate) + } + + return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText + refreshInfo) } func helpText(m model.State) string { @@ -242,25 +271,28 @@ func helpText(m model.State) string { โ•‘ Navigation โ•‘ โ•‘ j/k, โ†‘/โ†“ : Move selection / : Filter jobs โ•‘ โ•‘ 1 : Job list view 2 : Datasets view โ•‘ -โ•‘ 3 : Experiments view v : Queue view โ•‘ +โ•‘ 3 : Experiments view q : Queue view โ•‘ โ•‘ g : GPU view o : Container view โ•‘ -โ•‘ s : Settings view โ•‘ +โ•‘ s : Settings view n : Narrative view โ•‘ +โ•‘ m : Team view e : Experiment history โ•‘ +โ•‘ c : Config view l : Logs (job stream) โ•‘ +โ•‘ E : Export job @ : Filter by team โ•‘ โ•‘ โ•‘ โ•‘ Actions โ•‘ โ•‘ t : Queue job a : Queue w/ args โ•‘ -โ•‘ c : Cancel task d : Delete pending โ•‘ +โ•‘ x : Cancel task d : Delete pending โ•‘ โ•‘ f : Mark as failed r : Refresh all โ•‘ โ•‘ G : Refresh GPU only โ•‘ โ•‘ โ•‘ โ•‘ General โ•‘ -โ•‘ h or ? : Toggle this help q/Ctrl+C : Quit โ•‘ +โ•‘ h or ? : Toggle this help Ctrl+C : Quit โ•‘ โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•` } func getQuickHelp(m model.State) string { if m.ActiveView == model.ViewModeSettings { - return " โ†‘/โ†“:move enter:select esc:exit settings q:quit" + return " โ†‘/โ†“:move enter:select esc:exit settings Ctrl+C:quit" } - return " h:help 1:jobs 2:datasets 3:experiments v:queue g:gpu o:containers " + - "s:settings t:queue r:refresh q:quit" + return " h:help 1:jobs 2:datasets 3:experiments q:queue g:gpu o:containers " + + "n:narrative m:team e:history c:config l:logs E:export @:filter s:settings t:trigger r:refresh Ctrl+C:quit" } diff --git a/cmd/tui/main.go b/cmd/tui/main.go index b6fb2ad..573afd0 100644 --- a/cmd/tui/main.go +++ b/cmd/tui/main.go @@ -5,6 +5,7 @@ import ( "log" "os" "os/signal" + "path/filepath" "syscall" tea "github.com/charmbracelet/bubbletea" @@ -41,6 +42,16 @@ func (m AppModel) View() string { } func main() { + // Redirect logs to file to prevent TUI disruption + homeDir, _ := os.UserHomeDir() + logDir := filepath.Join(homeDir, ".ml", "logs") + os.MkdirAll(logDir, 0755) + logFile, logErr := os.OpenFile(filepath.Join(logDir, "tui.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if logErr == nil { + log.SetOutput(logFile) + defer logFile.Close() + } + // Parse authentication flags authFlags := auth.ParseAuthFlags() if err := auth.ValidateFlags(authFlags); err != nil { @@ -85,10 +96,10 @@ func main() { log.Printf(" 4. Run TUI: ./bin/tui") log.Printf("") log.Printf("Example ~/.ml/config.toml:") - log.Printf(" worker_host = \"localhost\"") - log.Printf(" worker_user = \"your_username\"") - log.Printf(" worker_base = \"~/ml_jobs\"") - log.Printf(" worker_port = 22") + log.Printf(" mode = \"dev\"") + log.Printf(" # Paths auto-resolve based on mode:") + log.Printf(" # dev mode: ./data/dev/experiments") + log.Printf(" # prod mode: ./data/prod/experiments") log.Printf(" api_key = \"your_api_key_here\"") log.Printf("") log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs") @@ -101,6 +112,11 @@ func main() { } log.Printf("Loaded TOML configuration from %s", cliConfPath) + // Force local mode - TUI runs on server with direct filesystem access + cfg.Host = "" + // Clear BasePath to force mode-based path resolution + cfg.BasePath = "" + // Validate authentication configuration if err := cfg.Auth.ValidateAuthConfig(); err != nil { log.Fatalf("Invalid authentication configuration: %v", err) @@ -130,7 +146,7 @@ func main() { srv, err := services.NewMLServer(cfg) if err != nil { - log.Fatalf("Failed to connect to server: %v", err) + log.Fatalf("Failed to initialize local server: %v", err) } defer func() { if err := srv.Close(); err != nil { @@ -138,19 +154,24 @@ func main() { } }() + // TaskQueue is optional for local mode tq, err := services.NewTaskQueue(cfg) if err != nil { - log.Printf("Failed to connect to Redis: %v", err) - return + log.Printf("Warning: Failed to connect to Redis: %v", err) + log.Printf("Continuing without task queue functionality") + tq = nil + } + if tq != nil { + defer func() { + if err := tq.Close(); err != nil { + log.Printf("task queue close error: %v", err) + } + }() } - defer func() { - if err := tq.Close(); err != nil { - log.Printf("task queue close error: %v", err) - } - }() - // Initialize logger with error level and no debug output - logger := logging.NewLogger(-4, false) // -4 = slog.LevelError + // Initialize logger with file output only (prevents TUI disruption) + logFilePath := filepath.Join(logDir, "tui.log") + logger := logging.NewFileLogger(-4, false, logFilePath) // -4 = slog.LevelError // Initialize State and Controller var effectiveAPIKey string @@ -177,14 +198,17 @@ func main() { go func() { <-sigChan + logger.Info("Received shutdown signal, closing TUI...") p.Quit() }() if _, err := p.Run(); err != nil { _ = p.ReleaseTerminal() - log.Printf("Error running TUI: %v", err) + logger.Error("Error running TUI", "error", err) return } + // Ensure terminal is released and resources are closed via defer statements _ = p.ReleaseTerminal() + logger.Info("TUI shutdown complete") } diff --git a/configs/api/dev-local.yaml b/configs/api/dev-local.yaml index f0ca12c..bab91e4 100644 --- a/configs/api/dev-local.yaml +++ b/configs/api/dev-local.yaml @@ -19,7 +19,7 @@ security: api_key_rotation_days: 90 audit_logging: enabled: true - log_path: "/tmp/fetchml-audit.log" + log_path: "data/dev/logs/fetchml-audit.log" rate_limit: enabled: false requests_per_minute: 60 @@ -46,8 +46,8 @@ database: logging: level: "info" - file: "" - audit_log: "" + file: "data/dev/logs/fetchml.log" + audit_log: "data/dev/logs/fetchml-audit.log" resources: max_workers: 1 diff --git a/configs/api/dev.yaml b/configs/api/dev.yaml index e548986..0f5a978 100644 --- a/configs/api/dev.yaml +++ b/configs/api/dev.yaml @@ -1,6 +1,6 @@ -base_path: "./data/experiments" +base_path: "./data/dev/experiments" -data_dir: "./data/active" +data_dir: "./data/dev/active" auth: enabled: false @@ -19,7 +19,7 @@ security: api_key_rotation_days: 90 audit_logging: enabled: true - log_path: "./data/fetchml-audit.log" + log_path: "./data/dev/logs/fetchml-audit.log" rate_limit: enabled: false requests_per_minute: 60 @@ -42,12 +42,12 @@ redis: database: type: "sqlite" - connection: "./data/fetchml.sqlite" + connection: "./data/dev/fetchml.sqlite" logging: level: "info" - file: "" - audit_log: "" + file: "./data/dev/logs/fetchml.log" + audit_log: "./data/dev/logs/fetchml-audit.log" resources: max_workers: 1 diff --git a/configs/api/prod.yaml b/configs/api/prod.yaml index 726d1d2..914ae8d 100644 --- a/configs/api/prod.yaml +++ b/configs/api/prod.yaml @@ -1,6 +1,6 @@ -base_path: "/app/data/experiments" +base_path: "/app/data/prod/experiments" -data_dir: "/data/active" +data_dir: "/app/data/prod/active" auth: enabled: true @@ -45,12 +45,12 @@ redis: database: type: "sqlite" - connection: "/app/data/experiments/fetch_ml.sqlite" + connection: "/app/data/prod/fetch_ml.sqlite" logging: level: "info" - file: "/logs/fetch_ml.log" - audit_log: "" + file: "/app/data/prod/logs/fetch_ml.log" + audit_log: "/app/data/prod/logs/audit.log" resources: max_workers: 2 diff --git a/internal/api/routes.go b/internal/api/routes.go index ac96eb2..d8df8ed 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -36,6 +36,22 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { // Register HTTP API handlers s.handlers.RegisterHandlers(mux) + + // Register new REST API endpoints for TUI + jobsHandler := jobs.NewHandler( + s.expManager, + s.logger, + s.taskQueue, + s.db, + s.config.BuildAuthConfig(), + nil, + ) + + // Experiment history endpoint: GET /api/experiments/:id/history + mux.HandleFunc("GET /api/experiments/{id}/history", jobsHandler.GetExperimentHistoryHTTP) + + // Team jobs endpoint: GET /api/jobs?all_users=true + mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP) } // registerHealthRoutes sets up health check endpoints diff --git a/internal/api/ws/handler.go b/internal/api/ws/handler.go index eb224b5..daafaa0 100644 --- a/internal/api/ws/handler.go +++ b/internal/api/ws/handler.go @@ -12,6 +12,8 @@ import ( "os" "path/filepath" "strings" + "sync" + "time" "github.com/gorilla/websocket" "github.com/jfraeys/fetch_ml/internal/audit" @@ -110,6 +112,22 @@ const ( PermJupyterRead = "jupyter:read" ) +// ClientType represents the type of WebSocket client +type ClientType int + +const ( + ClientTypeCLI ClientType = iota + ClientTypeTUI +) + +// Client represents a connected WebSocket client +type Client struct { + conn *websocket.Conn + Type ClientType + User string + RemoteAddr string +} + // Handler provides WebSocket handling type Handler struct { authConfig *auth.Config @@ -125,6 +143,10 @@ type Handler struct { jobsHandler *jobs.Handler jupyterHandler *jupyterj.Handler datasetsHandler *datasets.Handler + + // Client management for push updates + clients map[*Client]bool + clientsMu sync.RWMutex } // NewHandler creates a new WebSocket handler @@ -158,6 +180,7 @@ func NewHandler( jobsHandler: jobsHandler, jupyterHandler: jupyterHandler, datasetsHandler: datasetsHandler, + clients: make(map[*Client]bool), } } @@ -217,6 +240,25 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleConnection(conn *websocket.Conn) { h.logger.Info("websocket connection established", "remote", conn.RemoteAddr()) + // Register client + client := &Client{ + conn: conn, + Type: ClientTypeTUI, // Assume TUI for now, could detect from handshake + User: "tui-user", + RemoteAddr: conn.RemoteAddr().String(), + } + + h.clientsMu.Lock() + h.clients[client] = true + h.clientsMu.Unlock() + + defer func() { + h.clientsMu.Lock() + delete(h.clients, client) + h.clientsMu.Unlock() + conn.Close() + }() + for { messageType, payload, err := conn.ReadMessage() if err != nil { @@ -765,3 +807,66 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro "message": "Outcome updated", }) } + +// BroadcastJobUpdate sends job status update to all connected TUI clients +func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) { + h.clientsMu.RLock() + defer h.clientsMu.RUnlock() + + msg := map[string]any{ + "type": "job_update", + "job_name": jobName, + "status": status, + "progress": progress, + "time": time.Now().Unix(), + } + + payload, _ := json.Marshal(msg) + + for client := range h.clients { + if client.Type == ClientTypeTUI { + if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil { + h.logger.Warn("failed to broadcast to client", "error", err, "client", client.RemoteAddr) + } + } + } +} + +// BroadcastGPUUpdate sends GPU status update to all connected TUI clients +func (h *Handler) BroadcastGPUUpdate(deviceID, utilization int, memoryUsed, memoryTotal int64) { + h.clientsMu.RLock() + defer h.clientsMu.RUnlock() + + msg := map[string]any{ + "type": "gpu_update", + "device_id": deviceID, + "utilization": utilization, + "memory_used": memoryUsed, + "memory_total": memoryTotal, + "time": time.Now().Unix(), + } + + payload, _ := json.Marshal(msg) + + for client := range h.clients { + if client.Type == ClientTypeTUI { + if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil { + h.logger.Warn("failed to broadcast GPU update", "error", err, "client", client.RemoteAddr) + } + } + } +} + +// GetConnectedClientCount returns the number of connected TUI clients +func (h *Handler) GetConnectedClientCount() int { + h.clientsMu.RLock() + defer h.clientsMu.RUnlock() + + count := 0 + for client := range h.clients { + if client.Type == ClientTypeTUI { + count++ + } + } + return count +} diff --git a/internal/config/smart_defaults.go b/internal/config/smart_defaults.go index 5c218df..c255912 100644 --- a/internal/config/smart_defaults.go +++ b/internal/config/smart_defaults.go @@ -229,6 +229,37 @@ func (s *SmartDefaults) ExpandPath(path string) string { return path } +// ModeBasedPath returns the appropriate path for a given mode +func ModeBasedPath(mode string, subpath string) string { + switch mode { + case "dev": + return filepath.Join("./data/dev", subpath) + case "prod": + return filepath.Join("./data/prod", subpath) + case "ci": + return filepath.Join("./data/ci", subpath) + case "prod-smoke": + return filepath.Join("./data/prod-smoke", subpath) + default: + return filepath.Join("./data/dev", subpath) + } +} + +// ModeBasedBasePath returns the experiments base path for a given mode +func ModeBasedBasePath(mode string) string { + return ModeBasedPath(mode, "experiments") +} + +// ModeBasedDataDir returns the data directory for a given mode +func ModeBasedDataDir(mode string) string { + return ModeBasedPath(mode, "active") +} + +// ModeBasedLogDir returns the logs directory for a given mode +func ModeBasedLogDir(mode string) string { + return ModeBasedPath(mode, "logs") +} + // GetEnvironmentDescription returns a human-readable description func (s *SmartDefaults) GetEnvironmentDescription() string { switch s.Profile {