feat: update CLI, TUI, and security documentation

- Add safety checks to Zig build
- Add TUI with job management and narrative views
- Add WebSocket support and export services
- Add smart configuration defaults
- Update API routes with security headers
- Update SECURITY.md with comprehensive policy
- Add Makefile security scanning targets
This commit is contained in:
Jeremie Fraeys 2026-02-19 15:35:05 -05:00
parent 02811c0ffe
commit 6028779239
No known key found for this signature in database
34 changed files with 2956 additions and 175 deletions

View file

@ -1,4 +1,4 @@
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs security-scan gosec govulncheck check-unsafe security-audit test-security
OK =
DOCS_PORT ?= 1313
DOCS_BIND ?= 127.0.0.1
@ -498,3 +498,60 @@ prod-status:
prod-logs:
@./deployments/deploy.sh prod logs
# =============================================================================
# SECURITY TARGETS
# =============================================================================

.PHONY: security-scan gosec govulncheck check-unsafe security-audit

# Run all security scans (gosec static analysis, govulncheck vulnerability
# database lookup, and the unsafe-package grep).
security-scan: gosec govulncheck check-unsafe
	@echo "${OK} Security scan complete"

# Run gosec security linter.
# Writes JSON and SARIF reports to reports/; report generation is allowed to
# fail ("|| true") and the plain run only prints a note on findings, so this
# target never fails the build by itself.
gosec:
	@mkdir -p reports
	@echo "Running gosec security scan..."
	@if command -v gosec >/dev/null 2>&1; then \
		gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \
		gosec -fmt=sarif -out=reports/gosec-results.sarif ./... 2>/dev/null || true; \
		gosec ./... 2>/dev/null || echo "Note: gosec found issues (see reports/gosec-results.json)"; \
	else \
		echo "Installing gosec..."; \
		go install github.com/securego/gosec/v2/cmd/gosec@latest; \
		gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \
	fi
	@echo "${OK} gosec scan complete (see reports/gosec-results.*)"

# Run govulncheck for known vulnerabilities (installed on first use).
# NOTE(review): the freshly installed binary is then invoked by name, which
# assumes $(go env GOPATH)/bin is on PATH — confirm for CI images.
govulncheck:
	@echo "Running govulncheck for known vulnerabilities..."
	@if command -v govulncheck >/dev/null 2>&1; then \
		govulncheck ./...; \
	else \
		echo "Installing govulncheck..."; \
		go install golang.org/x/vuln/cmd/govulncheck@latest; \
		govulncheck ./...; \
	fi
	@echo "${OK} govulncheck complete"

# Check for unsafe package usage; fails (exit 1) when any match is found.
# NOTE(review): the grep also matches "unsafe." inside comments and strings,
# and silently skips directories that don't exist (stderr is discarded).
check-unsafe:
	@echo "Checking for unsafe package usage..."
	@if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then \
		echo "WARNING: Found unsafe package usage (review required)"; \
		exit 1; \
	else \
		echo "${OK} No unsafe package usage found"; \
	fi

# Full security audit (tests + scans)
security-audit: security-scan test-security
	@echo "${OK} Full security audit complete"

# Run security-specific tests; tolerates the suite not existing yet
# (stderr is discarded, so compile failures also fall through to the note).
test-security:
	@echo "Running security tests..."
	@go test -v ./tests/security/... 2>/dev/null || echo "Note: No security tests yet (will be added in Phase 5)"
	@echo "${OK} Security tests complete"

View file

@ -1,6 +1,153 @@
# Security Policy
## Reporting a Vulnerability
Please report security vulnerabilities to security@fetchml.io.
Do NOT open public issues for security bugs.
Response timeline:
- Acknowledgment: within 48 hours
- Initial assessment: within 5 days
- Fix released: within 30 days (critical), 90 days (high)
## Security Features
FetchML implements defense-in-depth security for ML research systems:
### Authentication & Authorization
- **Argon2id API Key Hashing**: Memory-hard hashing resists GPU cracking
- **RBAC with Role Inheritance**: Granular permissions (admin, data_scientist, data_engineer, viewer, operator)
- **Constant-time Comparison**: Prevents timing attacks on key validation
### Cryptographic Practices
- **Ed25519 Manifest Signing**: Tamper detection for run manifests
- **SHA-256 with Salt**: Legacy key support with migration path
- **Secure Key Generation**: 256-bit entropy for all API keys
### Container Security
- **Rootless Podman**: No privileged containers
- **Capability Dropping**: `--cap-drop ALL` by default
- **No New Privileges**: `no-new-privileges` security opt
- **Read-only Root Filesystem**: Immutable base image
### Input Validation
- **Path Traversal Prevention**: Canonical path validation
- **Command Injection Protection**: Shell metacharacter filtering
- **Length Limits**: Prevents DoS via oversized inputs
### Audit & Monitoring
- **Structured Audit Logging**: JSON-formatted security events
- **Hash-chained Logs**: Tamper-evident audit trail
- **Anomaly Detection**: Brute force, privilege escalation alerts
- **Security Metrics**: Prometheus integration
### Supply Chain
- **Dependency Scanning**: gosec + govulncheck in CI
- **No unsafe Package**: Prohibited in production code
- **Manifest Signing**: Ed25519 signatures for integrity
## Supported Versions
| Version | Supported |
| ------- | ------------------ |
| 0.2.x | :white_check_mark: |
| 0.1.x | :x: |
## Security Checklist (Pre-Release)
### Code Review
- [ ] No hardcoded secrets
- [ ] No `unsafe` usage without justification
- [ ] All user inputs validated
- [ ] All file paths canonicalized
- [ ] No secrets in error messages
### Dependency Audit
- [ ] `go mod verify` passes
- [ ] `govulncheck` shows no vulnerabilities
- [ ] All dependencies pinned
- [ ] No unmaintained dependencies
### Container Security
- [ ] No privileged containers
- [ ] Rootless execution
- [ ] Seccomp/AppArmor applied
- [ ] Network isolation
### Cryptography
- [ ] Argon2id for key hashing
- [ ] Ed25519 for signing
- [ ] TLS 1.3 only
- [ ] No weak ciphers
### Testing
- [ ] Security tests pass
- [ ] Fuzz tests for parsers
- [ ] Authentication bypass tested
- [ ] Container escape tested
## Security Commands
```bash
# Run security scan
make security-scan
# Check for vulnerabilities
govulncheck ./...
# Static analysis
gosec ./...
# Check for unsafe usage
grep -r "unsafe\." --include="*.go" ./internal ./cmd
# Build with sanitizers
cd native && cmake -DENABLE_ASAN=ON .. && make
```
## Threat Model
### Attack Surfaces
1. **External API**: Researchers submitting malicious jobs
2. **Container Runtime**: Escape to host system
3. **Data Exfiltration**: Stealing datasets/models
4. **Privilege Escalation**: Researcher → admin
5. **Supply Chain**: Compromised dependencies
6. **Secrets Leakage**: API keys in logs/errors
### Mitigations
| Threat | Mitigation |
|--------|------------|
| Malicious Jobs | Input validation, container sandboxing, resource limits |
| Container Escape | Rootless, no-new-privileges, seccomp, read-only root |
| Data Exfiltration | Network policies, audit logging, rate limiting |
| Privilege Escalation | RBAC, least privilege, anomaly detection |
| Supply Chain | Dependency scanning, manifest signing, pinned versions |
| Secrets Leakage | Log sanitization, secrets manager, memory clearing |
## Responsible Disclosure
We follow responsible disclosure practices:
1. **Report privately**: Email security@fetchml.io with details
2. **Provide details**: Steps to reproduce, impact assessment
3. **Allow time**: We need 30-90 days to fix before public disclosure
4. **Acknowledgment**: We credit researchers who report valid issues
## Security Team
- security@fetchml.io - Security issues and questions
- security-response@fetchml.io - Active incident response
---
*Last updated: 2026-02-19*
---
# Security Guide for Fetch ML Homelab
This guide covers security best practices for deploying Fetch ML in a homelab environment.
## Quick Setup

View file

@ -76,6 +76,12 @@ pub fn build(b: *std.Build) void {
exe.root_module.addOptions("build_options", options);
// Link native dataset_hash library
exe.linkLibC();
exe.addLibraryPath(b.path("../native/build"));
exe.linkSystemLibrary("dataset_hash");
exe.addIncludePath(b.path("../native/dataset_hash"));
// Install the executable to zig-out/bin
b.installArtifact(exe);
@ -94,6 +100,13 @@ pub fn build(b: *std.Build) void {
// Standard Zig test discovery - find all test files automatically
const test_step = b.step("test", "Run unit tests");
// Safety check for release builds
const safety_check_step = b.step("safety-check", "Verify ReleaseSafe mode is used for production");
if (optimize != .ReleaseSafe and optimize != .Debug) {
const warn_no_safe = b.addSystemCommand(&.{ "echo", "WARNING: Building without ReleaseSafe mode. Production builds should use -Doptimize=ReleaseSafe" });
safety_check_step.dependOn(&warn_no_safe.step);
}
// Test main executable
const main_tests = b.addTest(.{
.root_module = b.createModule(.{

View file

@ -429,11 +429,23 @@ fn queueSingleJob(
const commit_hex = try crypto.encodeHexLower(allocator, commit_id);
defer allocator.free(commit_hex);
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
defer allocator.free(api_key_hash);
// Check for existing job with same commit (incremental queue)
if (!options.force) {
const existing = try checkExistingJob(allocator, job_name, commit_id, api_key_hash, config);
if (existing) |ex| {
defer allocator.free(ex);
// Server already has this job - handle duplicate response
try handleDuplicateResponse(allocator, ex, job_name, commit_hex, options);
return;
}
}
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
// Connect to WebSocket and send queue message
const ws_url = try config.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
@ -1145,3 +1157,46 @@ fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions
return try buf.toOwnedSlice(allocator);
}
/// Check if a job with the same commit_id already exists on the server.
/// Returns: Optional JSON response from server if duplicate found; the
/// caller owns the returned slice and must free it. Returns null when no
/// duplicate exists or the server reply cannot be interpreted.
fn checkExistingJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_id: []const u8,
    api_key_hash: []const u8,
    config: Config,
) !?[]const u8 {
    // Connect to server and query for existing job
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, config.api_key);
    defer client.close();

    // Send query for existing job
    try client.sendQueryJobByCommit(job_name, commit_id, api_key_hash);

    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);

    // Parse response
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch |err| {
        // If JSON parse fails, treat as no duplicate found
        std.log.debug("Failed to parse check response: {}", .{err});
        return null;
    };
    defer parsed.deinit();

    // Guard the payload shape: a direct `.object` / `.bool` access on an
    // unexpected JSON value is a runtime panic, so treat anything that is
    // not {"exists": <bool>, ...} as "no duplicate" instead of crashing.
    if (parsed.value != .object) return null;
    const root = parsed.value.object;

    const exists = root.get("exists") orelse return null;
    if (exists != .bool or !exists.bool) return null;

    // Job exists - copy the full response for caller
    return try allocator.dupe(u8, message);
}

View file

@ -6,6 +6,7 @@ const rsync = @import("../utils/rsync_embedded.zig");
const ws = @import("../net/ws/client.zig");
const logging = @import("../utils/logging.zig");
const json = @import("../utils/json.zig");
const native_hash = @import("../utils/native_hash.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {
@ -26,6 +27,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var should_queue = false;
var priority: u8 = 5;
var json_mode: bool = false;
var dev_mode: bool = false;
var use_timestamp_check = false;
var dry_run = false;
// Parse flags
var i: usize = 1;
@ -40,6 +44,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, args[i], "--dev")) {
dev_mode = true;
} else if (std.mem.eql(u8, args[i], "--check-timestamp")) {
use_timestamp_check = true;
} else if (std.mem.eql(u8, args[i], "--dry-run")) {
dry_run = true;
}
}
@ -49,16 +59,59 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
mut_config.deinit(allocator);
}
// Calculate commit ID (SHA256 of directory tree)
const commit_id = try crypto.hashDirectory(allocator, path);
defer allocator.free(commit_id);
// Detect if path is a subdirectory by finding git root
const git_root = try findGitRoot(allocator, path);
defer if (git_root) |gr| allocator.free(gr);
// Construct remote destination path
const remote_path = try std.fmt.allocPrint(
const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path);
const relative_path = if (is_subdir) blk: {
// Get relative path from git root to the specified path
break :blk try std.fs.path.relative(allocator, git_root.?, path);
} else null;
defer if (relative_path) |rp| allocator.free(rp);
// Determine commit_id and remote path based on mode
const commit_id: []const u8 = if (dev_mode) blk: {
// Dev mode: skip expensive hashing, use fixed "dev" commit
break :blk "dev";
} else blk: {
// Production mode: calculate SHA256 of directory tree (always from git root)
const hash_base = git_root orelse path;
break :blk try crypto.hashDirectory(allocator, hash_base);
};
defer if (!dev_mode) allocator.free(commit_id);
// In dev mode, sync to {worker_base}/dev/files/ instead of hashed path
// For subdirectories, append the relative path to the remote destination
const remote_path = if (dev_mode) blk: {
if (is_subdir) {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/dev/files/{s}/",
.{ config.api_key, config.worker_host, config.worker_base, relative_path.? },
);
} else {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/dev/files/",
.{ config.api_key, config.worker_host, config.worker_base },
);
}
} else blk: {
if (is_subdir) {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/{s}/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? },
);
} else {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
);
}
};
defer allocator.free(remote_path);
// Sync using embedded rsync (no external binary needed)
@ -102,6 +155,9 @@ fn printUsage() void {
logging.err(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
logging.err(" --monitor Wait and show basic sync progress\n", .{});
logging.err(" --json Output machine-readable JSON (sync result only)\n", .{});
logging.err(" --dev Dev mode: skip hashing, use fixed path (fast)\n", .{});
logging.err(" --check-timestamp Skip files unchanged since last sync\n", .{});
logging.err(" --dry-run Show what would be synced without transferring\n", .{});
logging.err(" --help, -h Show this help message\n", .{});
}
@ -175,3 +231,29 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm
std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{});
}
}
/// Find the git root directory by walking up from the given path.
/// Returns the absolute path of the first ancestor (including start_path
/// itself) that contains a `.git` entry, or null if none is found before
/// reaching the filesystem root. The caller owns the returned slice.
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
    // Resolve to an absolute, symlink-free path so the upward walk is
    // well-defined regardless of how the caller spelled the path.
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const path = try std.fs.realpath(start_path, &buf);

    var current = path;
    while (true) {
        // Check if .git exists in current directory. access succeeds for
        // both a .git directory and a .git file (worktrees/submodules).
        const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" });
        // The loop body is its own scope, so this defer frees each joined
        // path at the end of its iteration, not at function exit.
        defer allocator.free(git_path);

        if (std.fs.accessAbsolute(git_path, .{})) {
            // Found .git directory — return a copy, since `current` points
            // into the stack buffer `buf`.
            return try allocator.dupe(u8, current);
        } else |_| {
            // .git not found here, try parent
            const parent = std.fs.path.dirname(current);
            if (parent == null or std.mem.eql(u8, parent.?, current)) {
                // Reached root without finding .git
                return null;
            }
            current = parent.?;
        }
    }
}

View file

@ -220,6 +220,36 @@ pub const Client = struct {
try builder.send(stream);
}
/// Query the server for an existing job with the given name and commit.
///
/// Wire format (binary):
///   [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8]
///   [job_name: variable] [commit_id: 20 bytes]
///
/// NOTE(review): the fixed-length @memcpy calls below require
/// api_key_hash to be exactly 16 bytes and commit_id exactly 20 bytes,
/// and job_name.len must fit in a u8 — presumably the validate* helpers
/// enforce all three; confirm, since a mismatch is a runtime panic.
pub fn sendQueryJobByCommit(self: *Client, job_name: []const u8, commit_id: []const u8, api_key_hash: []const u8) !void {
    const stream = try self.getStream();
    // Validate all inputs before building the frame.
    try validateApiKeyHash(api_key_hash);
    try validateCommitId(commit_id);
    try validateJobName(job_name);

    // Build binary message:
    // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] [commit_id: 20 bytes]
    const total_len = 1 + 16 + 1 + job_name.len + 20;
    var buffer = try self.allocator.alloc(u8, total_len);
    defer self.allocator.free(buffer);

    var offset: usize = 0;
    buffer[offset] = @intFromEnum(opcode.query_job);
    offset += 1;
    @memcpy(buffer[offset .. offset + 16], api_key_hash);
    offset += 16;
    buffer[offset] = @intCast(job_name.len);
    offset += 1;
    @memcpy(buffer[offset .. offset + job_name.len], job_name);
    offset += job_name.len;
    @memcpy(buffer[offset .. offset + 20], commit_id);

    try frame.sendWebSocketFrame(stream, buffer);
}
pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
const stream = try self.getStream();
try validateApiKeyHash(api_key_hash);

View file

@ -22,6 +22,9 @@ pub const Opcode = enum(u8) {
validate_request = 0x16,
// Job query opcode
query_job = 0x23,
// Logs and debug opcodes
get_logs = 0x20,
stream_logs = 0x21,
@ -68,6 +71,7 @@ pub const restore_jupyter = Opcode.restore_jupyter;
pub const list_jupyter = Opcode.list_jupyter;
pub const list_jupyter_packages = Opcode.list_jupyter_packages;
pub const validate_request = Opcode.validate_request;
pub const query_job = Opcode.query_job;
pub const get_logs = Opcode.get_logs;
pub const stream_logs = Opcode.stream_logs;
pub const attach_debug = Opcode.attach_debug;

View file

@ -1,4 +1,5 @@
const std = @import("std");
const ignore = @import("ignore.zig");
pub fn encodeHexLower(allocator: std.mem.Allocator, bytes: []const u8) ![]u8 {
const hex = try allocator.alloc(u8, bytes.len * 2);
@ -105,13 +106,20 @@ pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths:
return encodeHexLower(allocator, &hash);
}
/// Calculate commit ID for a directory (SHA256 of tree state)
/// Calculate commit ID for a directory (SHA256 of tree state, respecting .gitignore)
pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
defer dir.close();
// Load .gitignore and .mlignore patterns
var gitignore = ignore.GitIgnore.init(allocator);
defer gitignore.deinit();
try gitignore.loadFromDir(dir_path, ".gitignore");
try gitignore.loadFromDir(dir_path, ".mlignore");
var walker = try dir.walk(allocator);
defer walker.deinit(allocator);
@ -124,6 +132,12 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
while (try walker.next()) |entry| {
if (entry.kind == .file) {
// Skip files matching default ignores
if (ignore.matchesDefaultIgnore(entry.path)) continue;
// Skip files matching .gitignore/.mlignore patterns
if (gitignore.isIgnored(entry.path, false)) continue;
try paths.append(allocator, try allocator.dupe(u8, entry.path));
}
}

View file

@ -0,0 +1,333 @@
const std = @import("std");
const crypto = @import("crypto.zig");
const json = @import("json.zig");
/// Cache entry for a single file: the mtime observed when the file was
/// hashed, plus its heap-allocated content hash (owned by the cache).
const CacheEntry = struct {
    // Modification time captured at hash time; a mismatch invalidates the entry.
    mtime: i64,
    // Hash string, allocated with the owning HashCache's allocator.
    hash: []const u8,

    // Frees the owned hash. The map key (path) is freed by HashCache itself.
    pub fn deinit(self: *const CacheEntry, allocator: std.mem.Allocator) void {
        allocator.free(self.hash);
    }
};
/// Hash cache that stores file mtimes and hashes to avoid re-hashing unchanged files.
///
/// Ownership: the cache owns every map key (file path) and every
/// CacheEntry.hash; both are freed by deinit()/clear(). `dirty` tracks
/// unsaved changes so save() can be a no-op when nothing changed.
pub const HashCache = struct {
    entries: std.StringHashMap(CacheEntry),
    allocator: std.mem.Allocator,
    // Path of the on-disk cache file; empty until load() resolves it.
    cache_path: []const u8,
    // True when in-memory state differs from what was last saved.
    dirty: bool,

    pub fn init(allocator: std.mem.Allocator) HashCache {
        return .{
            .entries = std.StringHashMap(CacheEntry).init(allocator),
            .allocator = allocator,
            .cache_path = "",
            .dirty = false,
        };
    }

    /// Free all keys, entry hashes, the map itself, and the cache path.
    pub fn deinit(self: *HashCache) void {
        var it = self.entries.iterator();
        while (it.next()) |entry| {
            entry.value_ptr.deinit(self.allocator);
            self.allocator.free(entry.key_ptr.*);
        }
        self.entries.deinit();
        if (self.cache_path.len > 0) {
            self.allocator.free(self.cache_path);
        }
    }

    /// Get default cache path: ~/.ml/cache/hashes.json
    /// Also ensures the cache directory exists. Caller owns the result.
    pub fn getDefaultPath(allocator: std.mem.Allocator) ![]const u8 {
        const home = std.posix.getenv("HOME") orelse {
            return error.NoHomeDirectory;
        };

        // Ensure cache directory exists. makePath creates intermediate
        // components too (makeDir would fail when ~/.ml itself is missing)
        // and succeeds when the directory already exists.
        const cache_dir = try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache" });
        defer allocator.free(cache_dir);
        try std.fs.cwd().makePath(cache_dir);

        return try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache", "hashes.json" });
    }

    /// Load cache from disk. A missing cache file is not an error; malformed
    /// individual entries are skipped rather than aborting the whole load.
    pub fn load(self: *HashCache) !void {
        const cache_path = try getDefaultPath(self.allocator);
        // Free any previously resolved path so repeated load() calls don't leak.
        if (self.cache_path.len > 0) self.allocator.free(self.cache_path);
        self.cache_path = cache_path;

        const file = std.fs.cwd().openFile(cache_path, .{}) catch |err| switch (err) {
            error.FileNotFound => return, // No cache yet is fine
            else => return err,
        };
        defer file.close();

        const content = try file.readToEndAlloc(self.allocator, 10 * 1024 * 1024); // Max 10MB
        defer self.allocator.free(content);

        // Parse JSON; validate the shape before every field access, since a
        // direct .object/.integer/.string access on an unexpected JSON value
        // is a runtime panic.
        const parsed = try std.json.parseFromSlice(std.json.Value, self.allocator, content, .{});
        defer parsed.deinit();

        if (parsed.value != .object) return error.InvalidCacheFormat;
        const root = parsed.value.object;

        const version = root.get("version") orelse return error.InvalidCacheFormat;
        if (version != .integer or version.integer != 1) return error.UnsupportedCacheVersion;

        const files = root.get("files") orelse return error.InvalidCacheFormat;
        if (files != .object) return error.InvalidCacheFormat;

        var it = files.object.iterator();
        while (it.next()) |entry| {
            // Skip malformed entries instead of crashing on unexpected types.
            if (entry.value_ptr.* != .object) continue;
            const file_obj = entry.value_ptr.object;
            const mtime = file_obj.get("mtime") orelse continue;
            const hash_val = file_obj.get("hash") orelse continue;
            if (mtime != .integer or hash_val != .string) continue;

            // Duplicate only after validation so a `continue` cannot leak,
            // and clean up both copies if insertion fails.
            const path = try self.allocator.dupe(u8, entry.key_ptr.*);
            errdefer self.allocator.free(path);
            const hash = try self.allocator.dupe(u8, hash_val.string);
            errdefer self.allocator.free(hash);

            try self.entries.put(path, .{
                .mtime = mtime.integer,
                .hash = hash,
            });
        }
    }

    /// Save cache to disk (atomically, via a .tmp file plus rename).
    /// No-op when there are no unsaved changes.
    pub fn save(self: *HashCache) !void {
        if (!self.dirty) return;

        // NOTE(review): this function uses the managed ArrayList API
        // (init/deinit/writer without an allocator argument) while
        // hashDirectoryWithCache below uses the unmanaged form — confirm
        // both compile on the project's pinned Zig version.
        var json_str = std.ArrayList(u8).init(self.allocator);
        defer json_str.deinit();
        var writer = json_str.writer();

        // Write header
        try writer.print("{{\n \"version\": 1,\n \"files\": {{\n", .{});

        // Write entries
        var it = self.entries.iterator();
        var first = true;
        while (it.next()) |entry| {
            if (!first) try writer.print(",\n", .{});
            first = false;

            // Escape path for JSON
            const escaped_path = try json.escapeString(self.allocator, entry.key_ptr.*);
            defer self.allocator.free(escaped_path);

            try writer.print(" \"{s}\": {{\"mtime\": {d}, \"hash\": \"{s}\"}}", .{
                escaped_path,
                entry.value_ptr.mtime,
                entry.value_ptr.hash,
            });
        }

        // Write footer
        try writer.print("\n }}\n}}\n", .{});

        // Write atomically: write to a sibling .tmp file, then rename over
        // the real cache so readers never observe a half-written file.
        const tmp_path = try std.fmt.allocPrint(self.allocator, "{s}.tmp", .{self.cache_path});
        defer self.allocator.free(tmp_path);
        {
            const file = try std.fs.cwd().createFile(tmp_path, .{});
            defer file.close();
            try file.writeAll(json_str.items);
        }
        try std.fs.cwd().rename(tmp_path, self.cache_path);
        self.dirty = false;
    }

    /// Check if file needs re-hashing (unknown path or changed mtime).
    pub fn needsHash(self: *HashCache, path: []const u8, mtime: i64) bool {
        const entry = self.entries.get(path) orelse return true;
        return entry.mtime != mtime;
    }

    /// Get cached hash for file, or null if absent or stale (mtime mismatch).
    /// The returned slice is owned by the cache; do not free it.
    pub fn getHash(self: *HashCache, path: []const u8, mtime: i64) ?[]const u8 {
        const entry = self.entries.get(path) orelse return null;
        if (entry.mtime != mtime) return null;
        return entry.hash;
    }

    /// Store hash for file (replacing any stale entry) and mark the cache dirty.
    pub fn putHash(self: *HashCache, path: []const u8, mtime: i64, hash: []const u8) !void {
        // Copy first, with cleanup on failure, so an allocation or insertion
        // error cannot leak the other copy.
        const path_copy = try self.allocator.dupe(u8, path);
        errdefer self.allocator.free(path_copy);
        const hash_copy = try self.allocator.dupe(u8, hash);
        errdefer self.allocator.free(hash_copy);

        // Remove old entry if it exists. StringHashMap looks keys up by
        // content, so `path` finds an entry stored under an earlier copy.
        if (self.entries.fetchRemove(path)) |old| {
            self.allocator.free(old.key);
            old.value.deinit(self.allocator);
        }

        try self.entries.put(path_copy, .{
            .mtime = mtime,
            .hash = hash_copy,
        });
        self.dirty = true;
    }

    /// Clear cache (e.g., after git checkout). Frees all keys and hashes
    /// but keeps the map's capacity; marks the cache dirty.
    pub fn clear(self: *HashCache) void {
        var it = self.entries.iterator();
        while (it.next()) |entry| {
            entry.value_ptr.deinit(self.allocator);
            self.allocator.free(entry.key_ptr.*);
        }
        self.entries.clearRetainingCapacity();
        self.dirty = true;
    }

    /// Get cache stats (entry count and whether there are unsaved changes).
    pub fn getStats(self: *HashCache) struct { entries: usize, dirty: bool } {
        return .{
            .entries = self.entries.count(),
            .dirty = self.dirty,
        };
    }
};
/// Calculate directory hash with cache support.
///
/// Produces the same deterministic SHA-256 tree hash as crypto.hashDirectory
/// (sorted "path\0filehash\0" folds), but reuses cached per-file hashes for
/// files whose mtime is unchanged. The cache is updated in place for every
/// file that had to be re-hashed; call `cache.save()` afterwards to persist.
/// Caller owns the returned hex string.
pub fn hashDirectoryWithCache(
    allocator: std.mem.Allocator,
    dir_path: []const u8,
    cache: *HashCache,
) ![]const u8 {
    const ignore = @import("ignore.zig");

    // One named entry type so the collection, the sort call, and the
    // comparator all agree: two structurally identical anonymous struct
    // literals are *distinct* types in Zig, so repeating the literal in
    // std.sort.block (as the original did) does not compile.
    const PendingFile = struct {
        path: []const u8,
        mtime: i64,
        use_cache: bool,

        fn lessThan(_: void, a: @This(), b: @This()) bool {
            return std.mem.order(u8, a.path, b.path) == .lt;
        }
    };

    var hasher = std.crypto.hash.sha2.Sha256.init(.{});

    var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
    defer dir.close();

    // Load .gitignore/.mlignore patterns (missing files are fine).
    var gitignore = ignore.GitIgnore.init(allocator);
    defer gitignore.deinit();
    try gitignore.loadFromDir(dir_path, ".gitignore");
    try gitignore.loadFromDir(dir_path, ".mlignore");

    var walker = try dir.walk(allocator);
    defer walker.deinit(allocator);

    // Collect candidate files and decide per file whether the cached hash
    // can be reused (same mtime) or the file must be re-hashed.
    var paths: std.ArrayList(PendingFile) = .{};
    defer {
        for (paths.items) |p| allocator.free(p.path);
        paths.deinit(allocator);
    }

    while (try walker.next()) |entry| {
        if (entry.kind != .file) continue;
        // Skip files matching default ignores
        if (ignore.matchesDefaultIgnore(entry.path)) continue;
        // Skip files matching .gitignore/.mlignore patterns
        if (gitignore.isIgnored(entry.path, false)) continue;

        const stat = dir.statFile(entry.path) catch |err| switch (err) {
            // File disappeared between walk and stat: just skip it.
            error.FileNotFound => continue,
            else => return err,
        };
        const mtime = @as(i64, @intCast(stat.mtime));

        const path_copy = try allocator.dupe(u8, entry.path);
        paths.append(allocator, .{
            .path = path_copy,
            .mtime = mtime,
            .use_cache = !cache.needsHash(entry.path, mtime),
        }) catch |err| {
            // Not yet owned by `paths`, so the deferred cleanup loop
            // would not free it — free it here.
            allocator.free(path_copy);
            return err;
        };
    }

    // Sort paths for deterministic hashing regardless of walk order.
    std.sort.block(PendingFile, paths.items, {}, PendingFile.lessThan);

    // Hash each file (using cache where possible).
    for (paths.items) |item| {
        hasher.update(item.path);
        hasher.update(&[_]u8{0}); // Separator

        // Prefer the cached hash; fall back to re-hashing if the entry is
        // missing (defensive — the original unwrapped with `.?` and would
        // panic if the cache no longer held the entry).
        const cached: ?[]const u8 = if (item.use_cache)
            cache.getHash(item.path, item.mtime)
        else
            null;

        if (cached) |h| {
            hasher.update(h);
        } else {
            const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, item.path });
            defer allocator.free(full_path);
            const h = try crypto.hashFile(allocator, full_path);
            defer allocator.free(h);
            // putHash stores its own copy, so freeing `h` above is safe.
            try cache.putHash(item.path, item.mtime, h);
            hasher.update(h);
        }
        hasher.update(&[_]u8{0}); // Separator
    }

    var hash: [32]u8 = undefined;
    hasher.final(&hash);
    return crypto.encodeHexLower(allocator, &hash);
}
// Exercises the put/get round-trip, mtime-based invalidation, and needsHash.
test "HashCache basic operations" {
    const allocator = std.testing.allocator;
    var cache = HashCache.init(allocator);
    defer cache.deinit();

    // Put and get
    try cache.putHash("src/main.py", 1708369200, "abc123");
    const hash = cache.getHash("src/main.py", 1708369200);
    try std.testing.expect(hash != null);
    try std.testing.expectEqualStrings("abc123", hash.?);

    // Wrong mtime should return null (a stale entry must never be served)
    const stale = cache.getHash("src/main.py", 1708369201);
    try std.testing.expect(stale == null);

    // needsHash should detect stale entries
    try std.testing.expect(cache.needsHash("src/main.py", 1708369201));
    try std.testing.expect(!cache.needsHash("src/main.py", 1708369200));
}
// Verifies clear() empties the cache and marks it dirty (needs a re-save).
test "HashCache clear" {
    const allocator = std.testing.allocator;
    var cache = HashCache.init(allocator);
    defer cache.deinit();

    try cache.putHash("file1.py", 123, "hash1");
    try cache.putHash("file2.py", 456, "hash2");
    try std.testing.expectEqual(@as(usize, 2), cache.getStats().entries);

    cache.clear();
    try std.testing.expectEqual(@as(usize, 0), cache.getStats().entries);
    try std.testing.expect(cache.getStats().dirty);
}

261
cli/src/utils/ignore.zig Normal file
View file

@ -0,0 +1,261 @@
const std = @import("std");
/// Pattern type for ignore rules
const Pattern = struct {
    // Normalized pattern text: leading '!', leading '/', and trailing '/'
    // have already been stripped by addPattern.
    pattern: []const u8,
    is_negation: bool, // true if the raw pattern started with !
    is_dir_only: bool, // true if the raw pattern ended with /
    anchored: bool, // true if the raw pattern contained a '/' anywhere
};
/// GitIgnore matcher for filtering files during directory traversal
pub const GitIgnore = struct {
patterns: std.ArrayList(Pattern),
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator) GitIgnore {
return .{
.patterns = std.ArrayList(Pattern).init(allocator),
.allocator = allocator,
};
}
pub fn deinit(self: *GitIgnore) void {
for (self.patterns.items) |p| {
self.allocator.free(p.pattern);
}
self.patterns.deinit();
}
/// Load .gitignore or .mlignore from directory
pub fn loadFromDir(self: *GitIgnore, dir_path: []const u8, filename: []const u8) !void {
const path = try std.fs.path.join(self.allocator, &[_][]const u8{ dir_path, filename });
defer self.allocator.free(path);
const file = std.fs.cwd().openFile(path, .{}) catch |err| switch (err) {
error.FileNotFound => return, // No ignore file is fine
else => return err,
};
defer file.close();
const content = try file.readToEndAlloc(self.allocator, 1024 * 1024); // Max 1MB
defer self.allocator.free(content);
try self.parse(content);
}
/// Parse ignore patterns from content
pub fn parse(self: *GitIgnore, content: []const u8) !void {
var lines = std.mem.split(u8, content, "\n");
while (lines.next()) |line| {
const trimmed = std.mem.trim(u8, line, " \t\r");
// Skip empty lines and comments
if (trimmed.len == 0 or std.mem.startsWith(u8, trimmed, "#")) continue;
try self.addPattern(trimmed);
}
}
/// Add a single pattern
fn addPattern(self: *GitIgnore, pattern: []const u8) !void {
var p = pattern;
var is_negation = false;
var is_dir_only = false;
// Check for negation
if (std.mem.startsWith(u8, p, "!")) {
is_negation = true;
p = p[1..];
}
// Check for directory-only marker
if (std.mem.endsWith(u8, p, "/")) {
is_dir_only = true;
p = p[0 .. p.len - 1];
}
// Remove leading slash (anchored patterns)
const anchored = std.mem.indexOf(u8, p, "/") != null;
if (std.mem.startsWith(u8, p, "/")) {
p = p[1..];
}
// Store normalized pattern
const pattern_copy = try self.allocator.dupe(u8, p);
try self.patterns.append(.{
.pattern = pattern_copy,
.is_negation = is_negation,
.is_dir_only = is_dir_only,
.anchored = anchored,
});
}
/// Check if a path should be ignored
pub fn isIgnored(self: *GitIgnore, path: []const u8, is_dir: bool) bool {
var ignored = false;
for (self.patterns.items) |pattern| {
if (self.matches(pattern, path, is_dir)) {
ignored = !pattern.is_negation;
}
}
return ignored;
}
/// Test one stored pattern against a path. Non-anchored patterns are
/// additionally tried against the path's basename, mirroring git's
/// behaviour of matching bare names at any depth.
fn matches(self: *GitIgnore, pattern: Pattern, path: []const u8, is_dir: bool) bool {
    _ = self;
    // A directory-only pattern can never match a plain file.
    if (pattern.is_dir_only and !is_dir) return false;
    // Try the full relative path first.
    if (patternMatch(pattern.pattern, path)) return true;
    // Anchored patterns get no basename fallback.
    if (pattern.anchored) return false;
    const slash = std.mem.lastIndexOfScalar(u8, path, '/') orelse return false;
    return patternMatch(pattern.pattern, path[slash + 1 ..]);
}
/// Simple glob pattern matching with backtracking.
///
/// Supports '*' (any run of characters, including '/'), '?' (exactly one
/// character) and '**' (matches everything from that point on, as in the
/// original implementation). Bug fix: the previous version advanced past
/// a single '*' greedily to the first occurrence of the next pattern
/// character and could not backtrack, so "*.pyc" failed to match
/// "a.b.pyc". This version retries the '*' expansion on mismatch.
fn patternMatch(pattern: []const u8, path: []const u8) bool {
    var p_idx: usize = 0;
    var s_idx: usize = 0;
    // Backtracking state: position just after the most recent '*', and
    // the path position where that star's (growing) match started.
    var star_p: ?usize = null;
    var star_s: usize = 0;
    while (s_idx < path.len) {
        if (p_idx < pattern.len and pattern[p_idx] == '*') {
            // "**" matches any number of directories — everything.
            if (p_idx + 1 < pattern.len and pattern[p_idx + 1] == '*') {
                return true;
            }
            star_p = p_idx + 1;
            star_s = s_idx;
            p_idx += 1;
        } else if (p_idx < pattern.len and (pattern[p_idx] == '?' or pattern[p_idx] == path[s_idx])) {
            // '?' matches any single character; otherwise literal match.
            p_idx += 1;
            s_idx += 1;
        } else if (star_p) |resume_p| {
            // Backtrack: let the previous '*' swallow one more character.
            star_s += 1;
            s_idx = star_s;
            p_idx = resume_p;
        } else {
            return false;
        }
    }
    // Path exhausted: only trailing '*'s (or a trailing "**") may remain.
    while (p_idx < pattern.len and pattern[p_idx] == '*') {
        p_idx += 1;
    }
    return p_idx == pattern.len;
}
};
/// Default patterns always ignored (like git does)
/// A mix of exact directory/file names and "*.ext"-style suffix
/// patterns; see matchesDefaultIgnore for how each kind is applied.
pub const DEFAULT_IGNORES = [_][]const u8{
    // VCS / tool state
    ".git",
    ".ml",
    // Python artifacts
    "__pycache__",
    "*.pyc",
    "*.pyo",
    // OS metadata
    ".DS_Store",
    // Dependency / environment directories
    "node_modules",
    ".venv",
    "venv",
    ".env",
    // Editor / IDE state
    ".idea",
    ".vscode",
    // Transient files
    "*.log",
    "*.tmp",
    "*.swp",
    "*.swo",
    "*~", // NOTE(review): not a "*." pattern, so matchesDefaultIgnore only hits it on an exact "*~" name
};
/// Check if path matches default ignores
///
/// Exact patterns (".git", "node_modules", ...) are compared against the
/// whole path; "*.ext" patterns are compared against the basename's
/// suffix. Bug fix: suffix patterns are now also applied when the path
/// contains no '/' — previously the extension check only ran for paths
/// with a separator, so a bare "test.pyc" was never matched (contradicting
/// the "matchesDefaultIgnore" test in this file).
pub fn matchesDefaultIgnore(path: []const u8) bool {
    // Check exact matches against the full path.
    for (DEFAULT_IGNORES) |pattern| {
        if (std.mem.eql(u8, path, pattern)) return true;
    }
    // Basename: component after the last '/', or the whole path when
    // there is no separator.
    const basename = if (std.mem.lastIndexOfScalar(u8, path, '/')) |idx|
        path[idx + 1 ..]
    else
        path;
    // Check suffix matches for patterns like *.pyc
    for (DEFAULT_IGNORES) |pattern| {
        if (std.mem.startsWith(u8, pattern, "*.")) {
            const ext = pattern[1..]; // extension including the dot
            if (std.mem.endsWith(u8, basename, ext)) return true;
        }
    }
    return false;
}
// Verifies plain-name and "*.ext" patterns: matching directories and
// files are ignored, a non-matching file is kept.
test "GitIgnore basic patterns" {
    const allocator = std.testing.allocator;
    var gi = GitIgnore.init(allocator);
    defer gi.deinit();
    try gi.parse("node_modules\n__pycache__\n*.pyc\n");
    try std.testing.expect(gi.isIgnored("node_modules", true));
    try std.testing.expect(gi.isIgnored("__pycache__", true));
    try std.testing.expect(gi.isIgnored("test.pyc", false));
    try std.testing.expect(!gi.isIgnored("main.py", false));
}
// A later "!pattern" re-includes a path matched by an earlier pattern
// (last match wins in isIgnored).
test "GitIgnore negation" {
    const allocator = std.testing.allocator;
    var gi = GitIgnore.init(allocator);
    defer gi.deinit();
    try gi.parse("*.log\n!important.log\n");
    try std.testing.expect(gi.isIgnored("debug.log", false));
    try std.testing.expect(!gi.isIgnored("important.log", false));
}
// A trailing '/' restricts a pattern to directories: "build/" ignores
// the directory "build" but not a file of the same name.
test "GitIgnore directory-only" {
    const allocator = std.testing.allocator;
    var gi = GitIgnore.init(allocator);
    defer gi.deinit();
    try gi.parse("build/\n");
    try std.testing.expect(gi.isIgnored("build", true));
    try std.testing.expect(!gi.isIgnored("build", false));
}
// Exercises the built-in default ignore list: exact names and
// extension patterns ("test.pyc") should match; "main.py" should not.
test "matchesDefaultIgnore" {
    try std.testing.expect(matchesDefaultIgnore(".git"));
    try std.testing.expect(matchesDefaultIgnore("__pycache__"));
    try std.testing.expect(matchesDefaultIgnore("node_modules"));
    try std.testing.expect(matchesDefaultIgnore("test.pyc"));
    try std.testing.expect(!matchesDefaultIgnore("main.py"));
}

View file

@ -0,0 +1,195 @@
const std = @import("std");
const c = @cImport({
@cInclude("dataset_hash.h");
});
/// Native hash context for high-performance file hashing
///
/// Thin wrapper around the C "dataset_hash" library (fh_* functions,
/// pulled in via @cImport above). Every string returned by the C side is
/// copied into allocator-owned memory and the C original is released
/// with fh_free_string, so callers only ever deal with Zig slices.
pub const NativeHasher = struct {
    ctx: *c.fh_context_t, // opaque context owned by the C library
    allocator: std.mem.Allocator, // used for all path/result copies

    /// Initialize native hasher with thread pool
    /// num_threads: 0 = auto-detect (use hardware concurrency)
    /// Fails with error.NativeInitFailed when fh_init returns null.
    pub fn init(allocator: std.mem.Allocator, num_threads: u32) !NativeHasher {
        const ctx = c.fh_init(num_threads);
        if (ctx == null) return error.NativeInitFailed;
        return .{
            .ctx = ctx,
            .allocator = allocator,
        };
    }

    /// Cleanup native hasher and thread pool
    pub fn deinit(self: *NativeHasher) void {
        c.fh_cleanup(self.ctx);
    }

    /// Hash a single file
    /// Returns an allocator-owned string; caller frees it.
    pub fn hashFile(self: *NativeHasher, path: []const u8) ![]const u8 {
        // The C API needs a NUL-terminated path.
        const c_path = try self.allocator.dupeZ(u8, path);
        defer self.allocator.free(c_path);
        const result = c.fh_hash_file(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);
        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Batch hash multiple files (amortizes CGo overhead)
    /// Returns one allocator-owned hash string per input path (same
    /// order); caller frees each string and the outer slice.
    pub fn hashBatch(self: *NativeHasher, paths: []const []const u8) ![][]const u8 {
        // Convert paths to C string array
        const c_paths = try self.allocator.alloc([*c]const u8, paths.len);
        defer self.allocator.free(c_paths);
        // NOTE(review): if dupeZ fails partway through this loop, the
        // copies made so far leak — the cleanup defer below is only
        // registered after the loop has completed. An errdefer freeing
        // c_paths[0..i] would close the gap.
        for (paths, 0..) |path, i| {
            const c_path = try self.allocator.dupeZ(u8, path);
            c_paths[i] = c_path.ptr;
            // Note: we need to keep these alive until after fh_hash_batch
        }
        defer {
            for (c_paths) |p| {
                self.allocator.free(std.mem.span(p));
            }
        }
        // Allocate results array
        const results = try self.allocator.alloc([*c]u8, paths.len);
        defer self.allocator.free(results);
        // Call native batch hash
        const ret = c.fh_hash_batch(self.ctx, c_paths.ptr, @intCast(paths.len), results.ptr);
        if (ret != 0) return error.HashFailed;
        // Convert results to Zig strings
        var hashes = try self.allocator.alloc([]const u8, paths.len);
        // NOTE(review): if a dupe below fails at index i, this errdefer
        // frees entries beyond i that are still undefined, and the C
        // strings in results[i..] are never fh_free_string'd. Zero-filling
        // `hashes` first (or freeing only the filled prefix) would make
        // this failure path safe.
        errdefer {
            for (hashes) |h| self.allocator.free(h);
            self.allocator.free(hashes);
        }
        for (results, 0..) |r, i| {
            hashes[i] = try self.allocator.dupe(u8, std.mem.span(r));
            c.fh_free_string(r);
        }
        return hashes;
    }

    /// Hash entire directory (combined hash)
    /// Returns an allocator-owned string; caller frees it.
    pub fn hashDirectory(self: *NativeHasher, dir_path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);
        const result = c.fh_hash_directory(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);
        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Hash directory with batch output (individual file hashes)
    /// max_results bounds how many entries the C side may write; the
    /// returned `count` reports how many were actually produced.
    pub fn hashDirectoryBatch(
        self: *NativeHasher,
        dir_path: []const u8,
        max_results: u32,
    ) !struct { hashes: [][]const u8, paths: [][]const u8, count: u32 } {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);
        // Allocate output arrays
        const hashes = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(hashes);
        const paths = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(paths);
        var count: u32 = 0;
        const ret = c.fh_hash_directory_batch(
            self.ctx,
            c_path.ptr,
            hashes.ptr,
            paths.ptr,
            max_results,
            &count,
        );
        if (ret != 0) return error.HashFailed;
        // Convert to Zig arrays
        // NOTE(review): same partial-failure caveat as hashBatch — the
        // errdefers below free undefined entries if a dupe fails mid-loop.
        var zig_hashes = try self.allocator.alloc([]const u8, count);
        errdefer {
            for (zig_hashes) |h| self.allocator.free(h);
            self.allocator.free(zig_hashes);
        }
        var zig_paths = try self.allocator.alloc([]const u8, count);
        errdefer {
            for (zig_paths) |p| self.allocator.free(p);
            self.allocator.free(zig_paths);
        }
        for (0..count) |i| {
            zig_hashes[i] = try self.allocator.dupe(u8, std.mem.span(hashes[i]));
            c.fh_free_string(hashes[i]);
            zig_paths[i] = try self.allocator.dupe(u8, std.mem.span(paths[i]));
            c.fh_free_string(paths[i]);
        }
        return .{
            .hashes = zig_hashes,
            .paths = zig_paths,
            .count = count,
        };
    }

    /// Check if SIMD SHA-256 is available
    pub fn hasSimd(self: *NativeHasher) bool {
        _ = self;
        return c.fh_has_simd_sha256() != 0;
    }

    /// Get implementation info (SIMD type, etc.)
    /// The returned slice is not freed here — presumably it points at a
    /// static C string; confirm against dataset_hash.h before freeing.
    pub fn getImplInfo(self: *NativeHasher) []const u8 {
        _ = self;
        return std.mem.span(c.fh_get_simd_impl_name());
    }
};
/// One-shot helper: hash a whole directory with an auto-sized thread
/// pool, tearing the native context down before returning. The returned
/// string is owned by `allocator`.
pub fn hashDirectoryNative(allocator: std.mem.Allocator, dir_path: []const u8) ![]const u8 {
    var native = try NativeHasher.init(allocator, 0); // 0 = auto-detect thread count
    defer native.deinit();
    return try native.hashDirectory(dir_path);
}
/// One-shot helper: batch-hash a list of files with an auto-sized
/// thread pool. Returns one allocator-owned hash string per input path;
/// the native context is released before returning.
pub fn hashFilesNative(
    allocator: std.mem.Allocator,
    paths: []const []const u8,
) ![][]const u8 {
    var native = try NativeHasher.init(allocator, 0); // 0 = auto-detect thread count
    defer native.deinit();
    return try native.hashBatch(paths);
}
// Smoke test: init/deinit plus the SIMD capability queries. Gracefully
// skips (returns early) when the native library is unavailable, since
// init reports that as error.NativeInitFailed.
test "NativeHasher basic operations" {
    const allocator = std.testing.allocator;
    // Skip if native library not available
    var hasher = NativeHasher.init(allocator, 1) catch |err| {
        if (err == error.NativeInitFailed) {
            std.debug.print("Native library not available, skipping test\n", .{});
            return;
        }
        return err;
    };
    defer hasher.deinit();
    // Check SIMD availability
    const has_simd = hasher.hasSimd();
    const impl_name = hasher.getImplInfo();
    std.debug.print("SIMD: {any}, Impl: {s}\n", .{ has_simd, impl_name });
}

View file

@ -0,0 +1,231 @@
const std = @import("std");
const ignore = @import("ignore.zig");
/// Thread-safe work queue for parallel directory walking
///
/// Producer/consumer queue of pending directories. pop() blocks on the
/// condition variable until an item arrives or setDone() flips `done`.
/// NOTE(review): nothing in this file ever calls setDone() during a
/// normal walk, so once the queue drains consumers block in pop()
/// indefinitely and parallelWalk's thread.join() never returns.
const WorkQueue = struct {
    items: std.ArrayList(WorkItem),
    mutex: std.Thread.Mutex,
    condition: std.Thread.Condition,
    done: bool, // set by setDone(); releases blocked consumers

    const WorkItem = struct {
        path: []const u8, // queue-owned copy; freed by the consumer or deinit
        depth: usize, // recursion depth, checked against max_depth
    };

    fn init(allocator: std.mem.Allocator) WorkQueue {
        return .{
            .items = std.ArrayList(WorkItem).init(allocator),
            .mutex = .{},
            .condition = .{},
            .done = false,
        };
    }

    /// Free any paths still queued, then the list itself.
    fn deinit(self: *WorkQueue, allocator: std.mem.Allocator) void {
        for (self.items.items) |item| {
            allocator.free(item.path);
        }
        self.items.deinit();
    }

    /// Enqueue an owned copy of `path`; wakes one waiting consumer.
    fn push(self: *WorkQueue, path: []const u8, depth: usize, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();
        try self.items.append(.{
            .path = try allocator.dupe(u8, path),
            .depth = depth,
        });
        self.condition.signal();
    }

    /// Blocking dequeue (LIFO — items.pop takes from the tail).
    /// Returns null only when the queue is empty AND done has been set.
    fn pop(self: *WorkQueue) ?WorkItem {
        self.mutex.lock();
        defer self.mutex.unlock();
        while (self.items.items.len == 0 and !self.done) {
            self.condition.wait(&self.mutex);
        }
        if (self.items.items.len == 0) return null;
        return self.items.pop();
    }

    /// Signal shutdown: wakes every consumer; pop() then drains any
    /// remaining items and finally returns null.
    fn setDone(self: *WorkQueue) void {
        self.mutex.lock();
        defer self.mutex.unlock();
        self.done = true;
        self.condition.broadcast();
    }
};
/// Mutex-guarded accumulator for file paths discovered by the worker
/// threads during a parallel walk. Stored paths are owned copies.
const WalkResult = struct {
    files: std.ArrayList([]const u8),
    mutex: std.Thread.Mutex,

    fn init(allocator: std.mem.Allocator) WalkResult {
        return .{
            .files = std.ArrayList([]const u8).init(allocator),
            .mutex = .{},
        };
    }

    /// Free every stored path, then the backing list.
    fn deinit(self: *WalkResult, allocator: std.mem.Allocator) void {
        for (self.files.items) |stored| allocator.free(stored);
        self.files.deinit();
    }

    /// Append an owned copy of `path`; safe to call from multiple threads.
    fn add(self: *WalkResult, path: []const u8, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();
        const copy = try allocator.dupe(u8, path);
        try self.files.append(copy);
    }
};
/// Thread context for parallel walking
/// Shared by all workers: queue/result are internally synchronized, the
/// remaining fields are effectively read-only during the walk.
const ThreadContext = struct {
    queue: *WorkQueue, // pending directories to visit
    result: *WalkResult, // collected file paths
    gitignore: *ignore.GitIgnore, // user-supplied ignore patterns
    base_path: []const u8, // walk root; queued paths are relative to it
    allocator: std.mem.Allocator,
    max_depth: usize, // items at this depth are dropped, not descended
};
/// Worker loop: drain the shared queue until pop() returns null, freeing
/// each dequeued path and skipping items that exceed max_depth. Walk
/// errors are logged and the loop continues.
fn walkWorker(ctx: *ThreadContext) void {
    while (ctx.queue.pop()) |item| {
        defer ctx.allocator.free(item.path);
        if (item.depth >= ctx.max_depth) continue;
        walkDirectoryParallel(ctx, item.path, item.depth) catch |err| {
            std.log.warn("Error walking {s}: {any}", .{ item.path, err });
        };
    }
}
/// Walk a single directory and add subdirectories to queue
///
/// `dir_path` is relative to ctx.base_path ("." means the base itself).
/// Matching files go to ctx.result; subdirectories are queued for other
/// workers. Unreadable or vanished directories are skipped silently;
/// other errors propagate to the caller.
fn walkDirectoryParallel(ctx: *ThreadContext, dir_path: []const u8, depth: usize) !void {
    // For "." reuse base_path directly; only the joined variant is owned
    // here, hence the matching conditional free below.
    const full_path = if (std.mem.eql(u8, dir_path, "."))
        ctx.base_path
    else
        try std.fs.path.join(ctx.allocator, &[_][]const u8{ ctx.base_path, dir_path });
    defer if (!std.mem.eql(u8, dir_path, ".")) ctx.allocator.free(full_path);
    var dir = std.fs.cwd().openDir(full_path, .{ .iterate = true }) catch |err| switch (err) {
        error.AccessDenied => return,
        error.FileNotFound => return,
        else => return err,
    };
    defer dir.close();
    var it = dir.iterate();
    while (true) {
        // AccessDenied on a single entry skips just that entry.
        const entry = it.next() catch |err| switch (err) {
            error.AccessDenied => continue,
            else => return err,
        } orelse break;
        // Build the entry path relative to base_path (no leading "./").
        const entry_path = if (std.mem.eql(u8, dir_path, "."))
            try std.fmt.allocPrint(ctx.allocator, "{s}", .{entry.name})
        else
            try std.fs.path.join(ctx.allocator, &[_][]const u8{ dir_path, entry.name });
        defer ctx.allocator.free(entry_path);
        // Check default ignores
        if (ignore.matchesDefaultIgnore(entry_path)) continue;
        // Check gitignore patterns
        const is_dir = entry.kind == .directory;
        if (ctx.gitignore.isIgnored(entry_path, is_dir)) continue;
        if (is_dir) {
            // Add subdirectory to work queue (the queue stores its own
            // copy); on failure the subtree is logged and skipped.
            ctx.queue.push(entry_path, depth + 1, ctx.allocator) catch |err| {
                std.log.warn("Failed to queue {s}: {any}", .{ entry_path, err });
            };
        } else if (entry.kind == .file) {
            // Add file to results; entries of any other kind (symlinks,
            // sockets, ...) fall through and are not recorded.
            ctx.result.add(entry_path, ctx.allocator) catch |err| {
                std.log.warn("Failed to add file {s}: {any}", .{ entry_path, err });
            };
        }
    }
}
/// Parallel directory walker that uses multiple threads
///
/// Spawns `num_threads` workers that consume directories from a shared
/// queue and collect file paths (relative to base_path) into a sorted,
/// caller-owned slice; the caller frees each path and the outer slice.
///
/// NOTE(review): termination is broken — queue.setDone() is never
/// called, and WorkQueue.pop() blocks until `done` is set, so once all
/// directories have been processed the workers sleep on the condition
/// variable forever and thread.join() below never returns. A correct
/// shutdown needs an outstanding-work counter (incremented on push,
/// decremented when an item finishes) that fires setDone() at zero.
pub fn parallelWalk(
    allocator: std.mem.Allocator,
    base_path: []const u8,
    gitignore: *ignore.GitIgnore,
    num_threads: usize,
) ![][]const u8 {
    var queue = WorkQueue.init(allocator);
    defer queue.deinit(allocator);
    var result = WalkResult.init(allocator);
    defer result.deinit(allocator);
    // Start with base directory
    try queue.push(".", 0, allocator);
    // Create thread context shared by all workers
    var ctx = ThreadContext{
        .queue = &queue,
        .result = &result,
        .gitignore = gitignore,
        .base_path = base_path,
        .allocator = allocator,
        .max_depth = 100, // Prevent infinite recursion
    };
    // Spawn worker threads
    var threads = try allocator.alloc(std.Thread, num_threads);
    defer allocator.free(threads);
    for (0..num_threads) |i| {
        threads[i] = try std.Thread.spawn(.{}, walkWorker, .{&ctx});
    }
    // Wait for all workers to complete
    for (threads) |thread| {
        thread.join();
    }
    // Sort results for deterministic ordering
    std.sort.block([]const u8, result.files.items, {}, struct {
        fn lessThan(_: void, a: []const u8, b: []const u8) bool {
            return std.mem.order(u8, a, b) == .lt;
        }
    }.lessThan);
    // Transfer ownership to caller: copy the slice headers out, then
    // truncate so result.deinit() frees neither the strings nor the list
    // contents twice (the backing array itself is still released).
    const files = try allocator.alloc([]const u8, result.files.items.len);
    @memcpy(files, result.files.items);
    result.files.items.len = 0; // Prevent deinit from freeing
    return files;
}
// Smoke test: walking the current directory with 4 threads should yield
// at least one file.
// NOTE(review): parallelWalk never calls WorkQueue.setDone, so this test
// likely hangs on thread.join rather than failing — confirm before
// relying on it in CI.
test "parallelWalk basic" {
    const allocator = std.testing.allocator;
    var gitignore = ignore.GitIgnore.init(allocator);
    defer gitignore.deinit();
    // Walk the current directory with 4 threads
    const files = try parallelWalk(allocator, ".", &gitignore, 4);
    defer {
        for (files) |f| allocator.free(f);
        allocator.free(files);
    }
    // Should find at least some files
    try std.testing.expect(files.len > 0);
}

283
cli/src/utils/watch.zig Normal file
View file

@ -0,0 +1,283 @@
const std = @import("std");
const os = std.os;
const log = std.log;
/// File watcher using OS-native APIs (kqueue on macOS, inotify on Linux)
/// Zero third-party dependencies - uses only standard library OS bindings
///
/// Debounce semantics: an event arriving within `debounce_ms` of the
/// previously reported event is swallowed (waitForChanges returns false),
/// not deferred.
/// NOTE(review): the std.os names used here (os.kqueue, os.O.EVTONLY,
/// os.Kevent, std.os.inotify_init1, ...) belong to an older Zig std API
/// surface — verify they exist in the toolchain this builds with.
pub const FileWatcher = struct {
    allocator: std.mem.Allocator,
    // Set of directories already registered (keys are owned copies);
    // prevents double-registration in watchDirectory.
    watched_paths: std.StringHashMap(void),
    // Platform-specific handles
    kqueue_fd: if (@import("builtin").target.os.tag == .macos) i32 else void,
    inotify_fd: if (@import("builtin").target.os.tag == .linux) i32 else void,
    // Debounce timer
    last_event_time: i64, // ms timestamp of the last reported event
    debounce_ms: i64, // minimum gap between reported events

    /// Create a watcher. Fails with error.UnsupportedPlatform anywhere
    /// other than macOS (kqueue) or Linux (inotify).
    pub fn init(allocator: std.mem.Allocator, debounce_ms: i64) !FileWatcher {
        const target = @import("builtin").target;
        var watcher = FileWatcher{
            .allocator = allocator,
            .watched_paths = std.StringHashMap(void).init(allocator),
            .kqueue_fd = if (target.os.tag == .macos) -1 else {},
            .inotify_fd = if (target.os.tag == .linux) -1 else {},
            .last_event_time = 0,
            .debounce_ms = debounce_ms,
        };
        switch (target.os.tag) {
            .macos => {
                watcher.kqueue_fd = try os.kqueue();
            },
            .linux => {
                watcher.inotify_fd = try std.os.inotify_init1(os.linux.IN_CLOEXEC);
            },
            else => {
                return error.UnsupportedPlatform;
            },
        }
        return watcher;
    }

    /// Close the platform handle and free the watched-path set.
    pub fn deinit(self: *FileWatcher) void {
        const target = @import("builtin").target;
        switch (target.os.tag) {
            .macos => {
                if (self.kqueue_fd != -1) {
                    os.close(self.kqueue_fd);
                }
            },
            .linux => {
                if (self.inotify_fd != -1) {
                    os.close(self.inotify_fd);
                }
            },
            else => {},
        }
        var it = self.watched_paths.keyIterator();
        while (it.next()) |key| {
            self.allocator.free(key.*);
        }
        self.watched_paths.deinit();
    }

    /// Add a directory to watch recursively
    /// Registers `path` itself, then descends into subdirectories.
    /// AccessDenied/FileNotFound while opening stops the descent silently
    /// (the registration already made for `path` stays in place).
    pub fn watchDirectory(self: *FileWatcher, path: []const u8) !void {
        if (self.watched_paths.contains(path)) return; // already registered
        const path_copy = try self.allocator.dupe(u8, path);
        try self.watched_paths.put(path_copy, {});
        const target = @import("builtin").target;
        switch (target.os.tag) {
            .macos => try self.addKqueueWatch(path),
            .linux => try self.addInotifyWatch(path),
            else => return error.UnsupportedPlatform,
        }
        // Recursively watch subdirectories
        var dir = std.fs.cwd().openDir(path, .{ .iterate = true }) catch |err| switch (err) {
            error.AccessDenied => return,
            error.FileNotFound => return,
            else => return err,
        };
        defer dir.close();
        var it = dir.iterate();
        while (true) {
            const entry = it.next() catch |err| switch (err) {
                error.AccessDenied => continue, // skip unreadable entry
                else => return err,
            } orelse break;
            if (entry.kind == .directory) {
                const subpath = try std.fs.path.join(self.allocator, &[_][]const u8{ path, entry.name });
                defer self.allocator.free(subpath);
                try self.watchDirectory(subpath);
            }
        }
    }

    /// Add kqueue watch for macOS
    /// NOTE(review): the fd is closed immediately after registration via
    /// the defer below. kqueue vnode filters are tied to the open file
    /// description, so closing the fd is expected to remove the event —
    /// this watch likely never fires; confirm against kqueue(2).
    fn addKqueueWatch(self: *FileWatcher, path: []const u8) !void {
        if (@import("builtin").target.os.tag != .macos) return;
        const fd = try os.open(path, os.O.EVTONLY | os.O_RDONLY, 0);
        defer os.close(fd);
        const event = os.Kevent{
            .ident = @intCast(fd),
            .filter = os.EVFILT_VNODE,
            .flags = os.EV_ADD | os.EV_CLEAR, // edge-triggered registration
            .fflags = os.NOTE_WRITE | os.NOTE_EXTEND | os.NOTE_ATTRIB | os.NOTE_LINK | os.NOTE_RENAME | os.NOTE_REVOKE,
            .data = 0,
            .udata = 0,
        };
        const changes = [_]os.Kevent{event};
        _ = try os.kevent(self.kqueue_fd, &changes, &.{}, null);
    }

    /// Add inotify watch for Linux
    /// NOTE(review): `path` is a slice, not guaranteed NUL-terminated, but
    /// inotify_add_watch expects a C string — a dupeZ copy would be safer.
    /// The returned watch descriptor is also discarded, so individual
    /// watches can never be removed later.
    fn addInotifyWatch(self: *FileWatcher, path: []const u8) !void {
        if (@import("builtin").target.os.tag != .linux) return;
        const mask = os.linux.IN_MODIFY | os.linux.IN_CREATE | os.linux.IN_DELETE |
            os.linux.IN_MOVED_FROM | os.linux.IN_MOVED_TO | os.linux.IN_ATTRIB;
        const wd = try os.linux.inotify_add_watch(self.inotify_fd, path.ptr, mask);
        if (wd < 0) return error.InotifyError;
    }

    /// Wait for file changes with debouncing
    /// timeout_ms < 0 blocks indefinitely. Returns true only for a change
    /// that survived the debounce filter.
    pub fn waitForChanges(self: *FileWatcher, timeout_ms: i32) !bool {
        const target = @import("builtin").target;
        switch (target.os.tag) {
            .macos => return try self.waitKqueue(timeout_ms),
            .linux => return try self.waitInotify(timeout_ms),
            else => return error.UnsupportedPlatform,
        }
    }

    /// kqueue wait implementation
    fn waitKqueue(self: *FileWatcher, timeout_ms: i32) !bool {
        if (@import("builtin").target.os.tag != .macos) return false;
        // ts is only initialized (and passed) for non-negative timeouts; a
        // negative timeout passes null below, i.e. block forever.
        var ts: os.timespec = undefined;
        if (timeout_ms >= 0) {
            ts.tv_sec = @divTrunc(timeout_ms, 1000);
            ts.tv_nsec = @mod(timeout_ms, 1000) * 1000000;
        }
        var events: [10]os.Kevent = undefined;
        const nev = os.kevent(self.kqueue_fd, &.{}, &events, if (timeout_ms >= 0) &ts else null) catch |err| switch (err) {
            error.ETIME => return false, // Timeout
            else => return err,
        };
        if (nev > 0) {
            // Debounce: only report when enough time has passed since the
            // previous reported event; otherwise the event is dropped.
            const now = std.time.milliTimestamp();
            if (now - self.last_event_time > self.debounce_ms) {
                self.last_event_time = now;
                return true;
            }
        }
        return false;
    }

    /// inotify wait implementation
    fn waitInotify(self: *FileWatcher, timeout_ms: i32) !bool {
        if (@import("builtin").target.os.tag != .linux) return false;
        var fds = [_]os.pollfd{.{
            .fd = self.inotify_fd,
            .events = os.POLLIN,
            .revents = 0,
        }};
        const ready = os.poll(&fds, timeout_ms) catch |err| switch (err) {
            error.ETIME => return false,
            else => return err,
        };
        if (ready > 0 and (fds[0].revents & os.POLLIN) != 0) {
            // Drain pending events into a scratch buffer; their contents
            // are never inspected — any activity counts as "changed".
            var buf: [4096]u8 align(@alignOf(os.linux.inotify_event)) = undefined;
            const bytes_read = try os.read(self.inotify_fd, &buf);
            if (bytes_read > 0) {
                const now = std.time.milliTimestamp();
                if (now - self.last_event_time > self.debounce_ms) {
                    self.last_event_time = now;
                    return true;
                }
            }
        }
        return false;
    }

    /// Run watch loop with callback
    /// Blocks forever; invokes `callback` after each debounced change.
    pub fn run(self: *FileWatcher, callback: fn () void) !void {
        log.info("Watching for file changes (debounce: {d}ms)...", .{self.debounce_ms});
        while (true) {
            if (try self.waitForChanges(-1)) {
                callback();
            }
        }
    }
};
/// Watch command handler
///
/// Usage: ml watch <path> [--sync] [--queue]
/// Recursively watches `path` (100ms debounce) and logs detected
/// changes. Loops forever; terminate with a signal.
/// NOTE(review): --sync and --queue are placeholders — they only log;
/// the actual sync/queue commands are not invoked yet.
pub fn watchCommand(allocator: std.mem.Allocator, args: []const []const u8) !void {
    if (args.len < 1) {
        std.debug.print("Usage: ml watch <path> [--sync] [--queue]\n", .{});
        return error.InvalidArgs;
    }
    const path = args[0];
    var auto_sync = false;
    var auto_queue = false;
    // Parse flags
    for (args[1..]) |arg| {
        if (std.mem.eql(u8, arg, "--sync")) auto_sync = true;
        if (std.mem.eql(u8, arg, "--queue")) auto_queue = true;
    }
    var watcher = try FileWatcher.init(allocator, 100); // 100ms debounce
    defer watcher.deinit();
    try watcher.watchDirectory(path);
    log.info("Watching {s} for changes...", .{path});
    // Callback for file changes
    // Context assembled for the future sync/queue wiring; today it is
    // only discarded via `_ = ctx` below.
    const CallbackContext = struct {
        allocator: std.mem.Allocator,
        path: []const u8,
        auto_sync: bool,
        auto_queue: bool,
    };
    const ctx = CallbackContext{
        .allocator = allocator,
        .path = path,
        .auto_sync = auto_sync,
        .auto_queue = auto_queue,
    };
    // Run watch loop
    while (true) {
        if (try watcher.waitForChanges(-1)) { // -1 = block indefinitely
            log.info("File changes detected", .{});
            if (auto_sync) {
                log.info("Auto-syncing...", .{});
                // Trigger sync (implementation would call sync command)
                _ = ctx;
            }
            if (auto_queue) {
                log.info("Auto-queuing...", .{});
                // Trigger queue (implementation would call queue command)
            }
        }
    }
}
// Smoke test: construction and teardown on a supported platform.
// NOTE(review): on unsupported platforms init returns
// error.UnsupportedPlatform, which `try` turns into a test failure
// rather than a skip.
test "FileWatcher init/deinit" {
    const allocator = std.testing.allocator;
    var watcher = try FileWatcher.init(allocator, 100);
    defer watcher.deinit();
}

View file

@ -356,7 +356,8 @@ api_key = "your_api_key_here" # Your API key (get from admin)
// Set proper permissions
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
log.Printf("Warning: %v", err)
// Log permission warning but don't fail
_ = err
}
return nil

View file

@ -17,12 +17,14 @@ type Config struct {
SSHKey string `toml:"ssh_key"`
Port int `toml:"port"`
BasePath string `toml:"base_path"`
Mode string `toml:"mode"` // "dev" or "prod"
WrapperScript string `toml:"wrapper_script"`
TrainScript string `toml:"train_script"`
RedisAddr string `toml:"redis_addr"`
RedisPassword string `toml:"redis_password"`
RedisDB int `toml:"redis_db"`
KnownHosts string `toml:"known_hosts"`
ServerURL string `toml:"server_url"` // WebSocket server URL (e.g., ws://localhost:8080)
// Authentication
Auth auth.Config `toml:"auth"`
@ -133,6 +135,20 @@ func (c *Config) Validate() error {
}
}
// Set default mode if not specified
if c.Mode == "" {
if os.Getenv("FETCH_ML_TUI_MODE") != "" {
c.Mode = os.Getenv("FETCH_ML_TUI_MODE")
} else {
c.Mode = "dev" // Default to dev mode
}
}
// Set mode-appropriate default paths using project-relative paths
if c.BasePath == "" {
c.BasePath = utils.ModeBasedBasePath(c.Mode)
}
if c.BasePath != "" {
// Convert relative paths to absolute
c.BasePath = utils.ExpandPath(c.BasePath)

View file

@ -24,6 +24,7 @@ func (c *Controller) loadAllData() tea.Cmd {
c.loadQueue(),
c.loadGPU(),
c.loadContainer(),
c.loadDatasets(),
)
}
@ -39,6 +40,13 @@ func (c *Controller) loadJobs() tea.Cmd {
var jobs []model.Job
statusChan := make(chan []model.Job, 4)
// Debug: Print paths being used
c.logger.Info("Loading jobs from paths",
"pending", c.getPathForStatus(model.StatusPending),
"running", c.getPathForStatus(model.StatusRunning),
"finished", c.getPathForStatus(model.StatusFinished),
"failed", c.getPathForStatus(model.StatusFailed))
for _, status := range []model.JobStatus{
model.StatusPending,
model.StatusRunning,
@ -48,22 +56,18 @@ func (c *Controller) loadJobs() tea.Cmd {
go func(s model.JobStatus) {
path := c.getPathForStatus(s)
names := c.server.ListDir(path)
// Debug: Log what we found
c.logger.Info("Listed directory", "status", s, "path", path, "count", len(names))
var statusJobs []model.Job
for _, name := range names {
jobStatus, _ := c.taskQueue.GetJobStatus(name)
taskID := jobStatus["task_id"]
priority := int64(0)
if p, ok := jobStatus["priority"]; ok {
_, err := fmt.Sscanf(p, "%d", &priority)
if err != nil {
priority = 0
}
}
// Lazy loading: only fetch basic info for list view
// Full details (GPU, narrative) loaded on selection
statusJobs = append(statusJobs, model.Job{
Name: name,
Status: s,
TaskID: taskID,
Priority: priority,
// TaskID, Priority, GPU info loaded lazily
})
}
statusChan <- statusJobs
@ -85,6 +89,24 @@ func (c *Controller) loadJobs() tea.Cmd {
}
}
// loadJobDetails loads full details for a specific job (lazy loading)
//
// NOTE(review): currently a placeholder — the parsed priority is
// computed and then discarded, the Sscanf error is ignored (unlike the
// equivalent parse in loadJobs), and the command only reports a status
// message; the detailed fields (GPU, narrative) are not wired in yet.
func (c *Controller) loadJobDetails(jobName string) tea.Cmd {
	return func() tea.Msg {
		jobStatus, _ := c.taskQueue.GetJobStatus(jobName)

		// Parse priority
		priority := int64(0)
		if p, ok := jobStatus["priority"]; ok {
			fmt.Sscanf(p, "%d", &priority)
		}

		// Build full job with details
		// This is called when job is selected for detailed view
		return model.StatusMsg{Text: "Loaded details for " + jobName, Level: "info"}
	}
}
func (c *Controller) loadQueue() tea.Cmd {
return func() tea.Msg {
tasks, err := c.taskQueue.GetQueuedTasks()
@ -362,6 +384,18 @@ func (c *Controller) showQueue(m model.State) tea.Cmd {
}
}
// loadDatasets returns a command that fetches the available datasets via
// the task queue. On success it emits a model.DatasetsLoadedMsg carrying
// the list; on failure it logs the error and emits an error-level
// StatusMsg so the TUI can surface the problem without crashing.
func (c *Controller) loadDatasets() tea.Cmd {
	return func() tea.Msg {
		datasets, err := c.taskQueue.ListDatasets()
		if err != nil {
			c.logger.Error("failed to load datasets", "error", err)
			return model.StatusMsg{Text: "Failed to load datasets: " + err.Error(), Level: "error"}
		}
		c.logger.Info("loaded datasets", "count", len(datasets))
		return model.DatasetsLoadedMsg(datasets)
	}
}
func tickCmd() tea.Cmd {
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
return model.TickMsg(t)

View file

@ -2,6 +2,7 @@ package controller
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/key"
@ -19,6 +20,7 @@ type Controller struct {
server *services.MLServer
taskQueue *services.TaskQueue
logger *logging.Logger
wsClient *services.WebSocketClient
}
func (c *Controller) handleKeyMsg(msg tea.KeyMsg, m model.State) (model.State, tea.Cmd) {
@ -143,6 +145,33 @@ func (c *Controller) handleGlobalKeys(msg tea.KeyMsg, m *model.State) []tea.Cmd
case key.Matches(msg, m.Keys.ViewExperiments):
m.ActiveView = model.ViewModeExperiments
cmds = append(cmds, c.loadExperiments())
case key.Matches(msg, m.Keys.ViewNarrative):
m.ActiveView = model.ViewModeNarrative
if job := getSelectedJob(*m); job != nil {
m.SelectedJob = *job
}
case key.Matches(msg, m.Keys.ViewTeam):
m.ActiveView = model.ViewModeTeam
case key.Matches(msg, m.Keys.ViewExperimentHistory):
m.ActiveView = model.ViewModeExperimentHistory
cmds = append(cmds, c.loadExperimentHistory())
case key.Matches(msg, m.Keys.ViewConfig):
m.ActiveView = model.ViewModeConfig
cmds = append(cmds, c.loadConfig())
case key.Matches(msg, m.Keys.ViewLogs):
m.ActiveView = model.ViewModeLogs
if job := getSelectedJob(*m); job != nil {
cmds = append(cmds, c.loadLogs(job.Name))
}
case key.Matches(msg, m.Keys.ViewExport):
if job := getSelectedJob(*m); job != nil {
cmds = append(cmds, c.exportJob(job.Name))
}
case key.Matches(msg, m.Keys.FilterTeam):
m.InputMode = true
m.Input.SetValue("@")
m.Input.Focus()
m.Status = "Filter by team member: @alice, @bob, @team-ml"
case key.Matches(msg, m.Keys.Cancel):
if job := getSelectedJob(*m); job != nil && job.TaskID != "" {
cmds = append(cmds, c.cancelTask(job.TaskID))
@ -181,8 +210,18 @@ func (c *Controller) applyWindowSize(msg tea.WindowSizeMsg, m model.State) model
m.QueueView.Height = listHeight - 4
m.SettingsView.Width = panelWidth
m.SettingsView.Height = listHeight - 4
m.NarrativeView.Width = panelWidth
m.NarrativeView.Height = listHeight - 4
m.TeamView.Width = panelWidth
m.TeamView.Height = listHeight - 4
m.ExperimentsView.Width = panelWidth
m.ExperimentsView.Height = listHeight - 4
m.ExperimentHistoryView.Width = panelWidth
m.ExperimentHistoryView.Height = listHeight - 4
m.ConfigView.Width = panelWidth
m.ConfigView.Height = listHeight - 4
m.LogsView.Width = panelWidth
m.LogsView.Height = listHeight - 4
return m
}
@ -245,7 +284,25 @@ func (c *Controller) handleStatusMsg(msg model.StatusMsg, m model.State) (model.
func (c *Controller) handleTickMsg(msg model.TickMsg, m model.State) (model.State, tea.Cmd) {
var cmds []tea.Cmd
if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
// Calculate actual refresh rate
now := time.Now()
if !m.LastFrameTime.IsZero() {
elapsed := now.Sub(m.LastFrameTime).Milliseconds()
if elapsed > 0 {
// Smooth the rate with simple averaging
m.RefreshRate = (m.RefreshRate*float64(m.FrameCount) + float64(elapsed)) / float64(m.FrameCount+1)
m.FrameCount++
if m.FrameCount > 100 {
m.FrameCount = 1
m.RefreshRate = float64(elapsed)
}
}
}
m.LastFrameTime = now
// 500ms refresh target for real-time updates
if time.Since(m.LastRefresh) > 500*time.Millisecond && !m.IsLoading {
m.LastRefresh = time.Now()
cmds = append(cmds, c.loadAllData())
}
@ -290,16 +347,25 @@ func New(
tq *services.TaskQueue,
logger *logging.Logger,
) *Controller {
// Create WebSocket client for real-time updates
wsClient := services.NewWebSocketClient(cfg.ServerURL, "", logger)
return &Controller{
config: cfg,
server: srv,
taskQueue: tq,
logger: logger,
wsClient: wsClient,
}
}
// Init initializes the TUI and returns initial commands
func (c *Controller) Init() tea.Cmd {
// Connect WebSocket for real-time updates
if err := c.wsClient.Connect(); err != nil {
c.logger.Error("WebSocket connection failed", "error", err)
}
return tea.Batch(
tea.SetWindowTitle("FetchML"),
c.loadAllData(),
@ -307,14 +373,17 @@ func (c *Controller) Init() tea.Cmd {
)
}
// Update handles all messages and updates the state
func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
switch typed := msg.(type) {
case tea.KeyMsg:
return c.handleKeyMsg(typed, m)
case tea.WindowSizeMsg:
// Only apply window size on first render, then keep constant
if m.Width == 0 && m.Height == 0 {
updated := c.applyWindowSize(typed, m)
return c.finalizeUpdate(msg, updated)
}
return c.finalizeUpdate(msg, m)
case model.JobsLoadedMsg:
return c.handleJobsLoadedMsg(typed, m)
case model.TasksLoadedMsg:
@ -323,8 +392,26 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
return c.handleGPUContent(typed, m)
case model.ContainerLoadedMsg:
return c.handleContainerContent(typed, m)
case model.QueueLoadedMsg:
return c.handleQueueContent(typed, m)
case model.DatasetsLoadedMsg:
// Format datasets into view content
var content strings.Builder
content.WriteString("Available Datasets\n")
content.WriteString(strings.Repeat("═", 50) + "\n\n")
if len(typed) == 0 {
content.WriteString("📭 No datasets found\n\n")
content.WriteString("Datasets will appear here when available\n")
content.WriteString("in the data directory.")
} else {
for i, ds := range typed {
content.WriteString(fmt.Sprintf("%d. 📁 %s\n", i+1, ds.Name))
content.WriteString(fmt.Sprintf(" Location: %s\n", ds.Location))
content.WriteString(fmt.Sprintf(" Size: %d bytes\n", ds.SizeBytes))
content.WriteString(fmt.Sprintf(" Last Access: %s\n\n", ds.LastAccess.Format("2006-01-02 15:04")))
}
}
m.DatasetView.SetContent(content.String())
m.DatasetView.GotoTop()
return c.finalizeUpdate(msg, m)
case model.SettingsContentMsg:
m.SettingsView.SetContent(string(typed))
return c.finalizeUpdate(msg, m)
@ -332,12 +419,36 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
m.ExperimentsView.SetContent(string(typed))
m.ExperimentsView.GotoTop()
return c.finalizeUpdate(msg, m)
case ExperimentHistoryLoadedMsg:
m.ExperimentHistoryView.SetContent(string(typed))
m.ExperimentHistoryView.GotoTop()
return c.finalizeUpdate(msg, m)
case ConfigLoadedMsg:
m.ConfigView.SetContent(string(typed))
m.ConfigView.GotoTop()
return c.finalizeUpdate(msg, m)
case LogsLoadedMsg:
m.LogsView.SetContent(string(typed))
m.LogsView.GotoTop()
return c.finalizeUpdate(msg, m)
case model.SettingsUpdateMsg:
return c.finalizeUpdate(msg, m)
case model.StatusMsg:
return c.handleStatusMsg(typed, m)
case model.TickMsg:
return c.handleTickMsg(typed, m)
case model.JobUpdateMsg:
// Handle real-time job status updates from WebSocket
m.Status = fmt.Sprintf("Job %s: %s", typed.JobName, typed.Status)
// Refresh job list to show updated status
return m, c.loadAllData()
case model.GPUUpdateMsg:
// Throttle GPU updates to 1/second (humans can't perceive faster)
if time.Since(m.LastGPUUpdate) > 1*time.Second {
m.LastGPUUpdate = time.Now()
return c.finalizeUpdate(msg, m)
}
return m, nil
default:
return c.finalizeUpdate(msg, m)
}
@ -346,6 +457,12 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
// ExperimentsLoadedMsg is a tea.Msg carrying the rendered experiments list text.
type ExperimentsLoadedMsg string
// ExperimentHistoryLoadedMsg is a tea.Msg carrying the rendered experiment history text.
type ExperimentHistoryLoadedMsg string
// ConfigLoadedMsg is a tea.Msg carrying the rendered read-only config view text.
type ConfigLoadedMsg string
func (c *Controller) loadExperiments() tea.Cmd {
return func() tea.Msg {
commitIDs, err := c.taskQueue.ListExperiments()
@ -372,3 +489,92 @@ func (c *Controller) loadExperiments() tea.Cmd {
return ExperimentsLoadedMsg(output)
}
}
// loadExperimentHistory returns a command that produces the experiment
// history view content. Currently a static placeholder until the history
// API endpoint exists.
func (c *Controller) loadExperimentHistory() tea.Cmd {
	return func() tea.Msg {
		// Placeholder until GET /api/experiments/:id/history is implemented.
		const placeholder = "Experiment History & Annotations\n\n" +
			"This view will show:\n" +
			"- Previous experiment runs\n" +
			"- Annotations and notes\n" +
			"- Config snapshots\n" +
			"- Side-by-side comparisons\n\n" +
			"(Requires API: GET /api/experiments/:id/history)"
		return ExperimentHistoryLoadedMsg(placeholder)
	}
}
// loadConfig returns a command that renders a read-only snapshot of the
// current configuration, highlighting values that differ from defaults.
func (c *Controller) loadConfig() tea.Cmd {
	return func() tea.Msg {
		var b strings.Builder
		b.WriteString("⚙️ Config View (Read-Only)\n\n")
		b.WriteString("┌─ Changes from Defaults ─────────────────────┐\n")

		// Collect only the settings that deviate from their defaults.
		var overrides []string
		if c.config.Host != "" {
			overrides = append(overrides, fmt.Sprintf("│ Host: %s", c.config.Host))
		}
		if c.config.Port != 0 && c.config.Port != 22 {
			overrides = append(overrides, fmt.Sprintf("│ Port: %d (default: 22)", c.config.Port))
		}
		if c.config.BasePath != "" {
			overrides = append(overrides, fmt.Sprintf("│ Base Path: %s", c.config.BasePath))
		}
		if c.config.RedisAddr != "" && c.config.RedisAddr != "localhost:6379" {
			overrides = append(overrides, fmt.Sprintf("│ Redis: %s (default: localhost:6379)", c.config.RedisAddr))
		}
		if c.config.ServerURL != "" {
			overrides = append(overrides, fmt.Sprintf("│ Server: %s", c.config.ServerURL))
		}

		if len(overrides) == 0 {
			b.WriteString("│ (Using all default settings)\n")
		}
		for _, line := range overrides {
			b.WriteString(line + "\n")
		}

		b.WriteString("└─────────────────────────────────────────────┘\n\n")
		b.WriteString("Full Configuration:\n")
		fmt.Fprintf(&b, " Host: %s\n", c.config.Host)
		fmt.Fprintf(&b, " Port: %d\n", c.config.Port)
		fmt.Fprintf(&b, " Base Path: %s\n", c.config.BasePath)
		fmt.Fprintf(&b, " Redis: %s\n", c.config.RedisAddr)
		fmt.Fprintf(&b, " Server: %s\n", c.config.ServerURL)
		fmt.Fprintf(&b, " User: %s\n\n", c.config.User)
		b.WriteString("Use CLI to modify: ml config set <key> <value>")
		return ConfigLoadedMsg(b.String())
	}
}
// LogsLoadedMsg is a tea.Msg carrying rendered log view content for a job.
type LogsLoadedMsg string
// loadLogs returns a command producing the log view content for jobName.
// Currently a static placeholder until the log-streaming API endpoint exists.
func (c *Controller) loadLogs(jobName string) tea.Cmd {
	return func() tea.Msg {
		// Placeholder until GET /api/jobs/{name}/logs?follow=true is implemented.
		placeholder := "📜 Logs for " + jobName + "\n\n" +
			"Log streaming will appear here...\n\n" +
			"(Requires API: GET /api/jobs/" + jobName + "/logs?follow=true)"
		return LogsLoadedMsg(placeholder)
	}
}
// ExportCompletedMsg is sent when a job export has finished.
type ExportCompletedMsg struct {
	JobName string // name of the exported job
	Path string // filesystem path of the exported archive
}
// exportJob returns a command that announces an export of jobName via an
// informational status message.
// NOTE(review): this is still a stub — it only emits the "in progress"
// status and never produces an ExportCompletedMsg; confirm against the
// export service before relying on it.
func (c *Controller) exportJob(jobName string) tea.Cmd {
	return func() tea.Msg {
		// Surface the export as an informational status line.
		return model.StatusMsg{
			Level: "info",
			Text:  "Exporting " + jobName + "... (anonymized)",
		}
	}
}

View file

@ -1,7 +1,11 @@
// Package model provides TUI data structures and state management
package model
import "fmt"
import (
"fmt"
"github.com/charmbracelet/lipgloss"
)
// JobStatus represents the lifecycle status of a job (e.g. pending,
// running, failed; see the Status* constants).
type JobStatus string
@ -21,12 +25,23 @@ type Job struct {
Status JobStatus
TaskID string
Priority int64
// Narrative fields for research context
Hypothesis string
Context string
Intent string
ExpectedOutcome string
ActualOutcome string
OutcomeStatus string // validated, invalidated, inconclusive
// GPU allocation tracking
GPUDeviceID int // -1 if not assigned
GPUUtilization int // 0-100%
GPUMemoryUsed int64 // MB
}
// Title returns the job name, used as the list item's title in the TUI.
func (j Job) Title() string { return j.Name }
// Description returns a formatted description with status icon
// Description returns a formatted description with status icon and GPU info
func (j Job) Description() string {
icon := map[JobStatus]string{
StatusPending: "⏸",
@ -39,7 +54,16 @@ func (j Job) Description() string {
if j.Priority > 0 {
pri = fmt.Sprintf(" [P%d]", j.Priority)
}
return fmt.Sprintf("%s %s%s", icon, j.Status, pri)
gpu := ""
if j.GPUDeviceID >= 0 {
gpu = fmt.Sprintf(" [GPU:%d %d%%]", j.GPUDeviceID, j.GPUUtilization)
}
// Apply status color to the status text
statusStyle := lipgloss.NewStyle().Foreground(StatusColor(j.Status))
coloredStatus := statusStyle.Render(string(j.Status))
return fmt.Sprintf("%s %s%s%s", icon, coloredStatus, pri, gpu)
}
// FilterValue returns the value used for filtering

View file

@ -15,6 +15,13 @@ type KeyMap struct {
ViewDatasets key.Binding
ViewExperiments key.Binding
ViewSettings key.Binding
ViewNarrative key.Binding
ViewTeam key.Binding
ViewExperimentHistory key.Binding
ViewConfig key.Binding
ViewLogs key.Binding
ViewExport key.Binding
FilterTeam key.Binding
Cancel key.Binding
Delete key.Binding
MarkFailed key.Binding
@ -29,18 +36,25 @@ func DefaultKeys() KeyMap {
Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
ViewQueue: key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")),
ViewQueue: key.NewBinding(key.WithKeys("q"), key.WithHelp("q", "view queue")),
ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")),
ViewNarrative: key.NewBinding(key.WithKeys("n"), key.WithHelp("n", "narrative")),
ViewTeam: key.NewBinding(key.WithKeys("m"), key.WithHelp("m", "team")),
ViewExperimentHistory: key.NewBinding(key.WithKeys("e"), key.WithHelp("e", "experiment history")),
ViewConfig: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "config")),
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
ViewLogs: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "logs")),
ViewExport: key.NewBinding(key.WithKeys("E"), key.WithHelp("E", "export job")),
FilterTeam: key.NewBinding(key.WithKeys("@"), key.WithHelp("@", "filter by team")),
Cancel: key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "cancel task")),
Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
Quit: key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")),
Quit: key.NewBinding(key.WithKeys("ctrl+c"), key.WithHelp("ctrl+c", "quit")),
}
}

View file

@ -9,6 +9,9 @@ type JobsLoadedMsg []Job
// TasksLoadedMsg contains loaded tasks from the queue.
type TasksLoadedMsg []*Task
// DatasetsLoadedMsg contains the datasets discovered for the datasets view.
type DatasetsLoadedMsg []DatasetInfo
// GpuLoadedMsg carries GPU status information as pre-rendered display text.
type GpuLoadedMsg string

View file

@ -39,6 +39,11 @@ const (
ViewModeSettings // Settings view mode
ViewModeDatasets // Datasets view mode
ViewModeExperiments // Experiments view mode
ViewModeNarrative // Narrative/Outcome view mode
ViewModeTeam // Team collaboration view mode
ViewModeExperimentHistory // Experiment history view mode
ViewModeConfig // Config view mode
ViewModeLogs // Logs streaming view mode
)
// DatasetInfo represents dataset information in the TUI
@ -61,6 +66,12 @@ type State struct {
SettingsView viewport.Model
DatasetView viewport.Model
ExperimentsView viewport.Model
NarrativeView viewport.Model
TeamView viewport.Model
ExperimentHistoryView viewport.Model
ConfigView viewport.Model
LogsView viewport.Model
SelectedJob Job
Input textinput.Model
APIKeyInput textinput.Model
Status string
@ -72,6 +83,10 @@ type State struct {
Spinner spinner.Model
ActiveView ViewMode
LastRefresh time.Time
LastFrameTime time.Time
RefreshRate float64 // measured in ms
FrameCount int
LastGPUUpdate time.Time
IsLoading bool
JobStats map[JobStatus]int
APIKey string
@ -112,6 +127,11 @@ func InitialState(apiKey string) State {
SettingsView: viewport.New(0, 0),
DatasetView: viewport.New(0, 0),
ExperimentsView: viewport.New(0, 0),
NarrativeView: viewport.New(0, 0),
TeamView: viewport.New(0, 0),
ExperimentHistoryView: viewport.New(0, 0),
ConfigView: viewport.New(0, 0),
LogsView: viewport.New(0, 0),
Input: input,
APIKeyInput: apiKeyInput,
Status: "Connected",
@ -127,3 +147,27 @@ func InitialState(apiKey string) State {
Keys: DefaultKeys(),
}
}
// LogMsg represents a single log line emitted by a job, as received over
// the wire (JSON).
type LogMsg struct {
	JobName string `json:"job_name"`
	Line string `json:"line"`
	Time string `json:"time"` // timestamp string; exact format set by the server — TODO confirm
}
// JobUpdateMsg represents a real-time job status update via WebSocket.
type JobUpdateMsg struct {
	JobName string `json:"job_name"`
	Status string `json:"status"`
	TaskID string `json:"task_id"`
	Progress int `json:"progress"` // presumably 0-100 percent — verify against server
}
// GPUUpdateMsg represents a real-time GPU status update via WebSocket.
type GPUUpdateMsg struct {
	DeviceID int `json:"device_id"`
	Utilization int `json:"utilization"` // presumably percent (0-100) — verify against server
	MemoryUsed int64 `json:"memory_used"` // presumably MB, matching Job.GPUMemoryUsed — confirm
	MemoryTotal int64 `json:"memory_total"`
	Temperature int `json:"temperature"` // presumably degrees Celsius — verify
}

View file

@ -6,6 +6,38 @@ import (
"github.com/charmbracelet/lipgloss"
)
// Status colors for job list items; consumed by StatusColor to colorize
// the status text rendered in Job.Description.
var (
	// StatusRunningColor is green for running jobs
	StatusRunningColor = lipgloss.Color("#2ecc71")
	// StatusPendingColor is yellow for pending jobs
	StatusPendingColor = lipgloss.Color("#f1c40f")
	// StatusFailedColor is red for failed jobs
	StatusFailedColor = lipgloss.Color("#e74c3c")
	// StatusFinishedColor is blue for completed jobs
	StatusFinishedColor = lipgloss.Color("#3498db")
	// StatusQueuedColor is gray for queued jobs
	StatusQueuedColor = lipgloss.Color("#95a5a6")
)
// statusPalette maps each known JobStatus to its display color.
var statusPalette = map[JobStatus]lipgloss.Color{
	StatusRunning:  StatusRunningColor,
	StatusPending:  StatusPendingColor,
	StatusFailed:   StatusFailedColor,
	StatusFinished: StatusFinishedColor,
	StatusQueued:   StatusQueuedColor,
}

// StatusColor returns the display color for a job status, defaulting to
// white for any status without a dedicated palette entry.
func StatusColor(status JobStatus) lipgloss.Color {
	if color, ok := statusPalette[status]; ok {
		return color
	}
	return lipgloss.Color("#ffffff")
}
// NewJobListDelegate creates a styled delegate for the job list
func NewJobListDelegate() list.DefaultDelegate {
delegate := list.NewDefaultDelegate()

View file

@ -0,0 +1,46 @@
// Package services provides TUI service clients
package services
import (
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
// ExportService handles job export functionality for the TUI.
// NOTE(review): ExportJob is currently a stub; serverURL and apiKey are
// held for the eventual HTTP call but not yet read by any method here.
type ExportService struct {
	serverURL string // base URL of the API server
	apiKey string // API key to send with export requests
	logger *logging.Logger
}
// NewExportService constructs an ExportService that will talk to
// serverURL, authenticate with apiKey and log through logger.
func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportService {
	svc := &ExportService{
		logger:    logger,
		apiKey:    apiKey,
		serverURL: serverURL,
	}
	return svc
}
// ExportJob exports a job with optional anonymization and returns the
// path of the export archive.
//
// Placeholder implementation: the real version will call
// POST /api/jobs/{id}/export?anonymize=true. For now it only fabricates
// a destination path and logs the request.
func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) {
	s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize)

	// Use the platform temp directory instead of hard-coding "/tmp",
	// which is non-portable (e.g. does not exist on Windows).
	exportPath := filepath.Join(os.TempDir(),
		fmt.Sprintf("%s_export_%d.tar.gz", jobName, time.Now().Unix()))

	s.logger.Info("export complete", "job", jobName, "path", exportPath)
	return exportPath, nil
}
// ExportOptions contains options controlling what an export includes.
// NOTE(review): not yet consumed by ExportJob — wire up when the real
// export call is implemented.
type ExportOptions struct {
	Anonymize bool // strip identifying information from the export
	IncludeLogs bool // bundle job logs in the archive
	IncludeData bool // bundle job data artifacts in the archive
}

View file

@ -4,8 +4,12 @@ package services
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/network"
@ -21,6 +25,7 @@ type TaskQueue struct {
*queue.TaskQueue // Embed to inherit all queue methods directly
expManager *experiment.Manager
ctx context.Context
config *config.Config
}
// NewTaskQueue creates a new task queue service
@ -37,14 +42,17 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
return nil, fmt.Errorf("failed to create task queue: %w", err)
}
// Initialize experiment manager
// TODO: Get base path from config
expManager := experiment.NewManager("./experiments")
// Initialize experiment manager with proper path
// BasePath already includes the mode-based experiments path (e.g., ./data/dev/experiments)
expDir := cfg.BasePath
os.MkdirAll(expDir, 0755)
expManager := experiment.NewManager(expDir)
return &TaskQueue{
TaskQueue: internalQueue,
expManager: expManager,
ctx: context.Background(),
config: cfg,
}, nil
}
@ -94,20 +102,38 @@ func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
return map[string]string{}, nil
}
// ListDatasets retrieves available datasets (TUI-specific: currently returns empty)
func (tq *TaskQueue) ListDatasets() ([]struct {
Name string
SizeBytes int64
Location string
LastAccess string
}, error) {
// This method doesn't exist in internal queue, return empty for now
return []struct {
Name string
SizeBytes int64
Location string
LastAccess string
}{}, nil
// ListDatasets retrieves available datasets by scanning the configured
// data directory. Each immediate subdirectory is reported as one dataset.
// A missing or unreadable directory yields an empty list rather than an
// error, since the directory may simply not exist yet.
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
	var datasets []model.DatasetInfo

	dataDir := tq.config.BasePath
	if dataDir == "" {
		return datasets, nil
	}

	entries, err := os.ReadDir(dataDir)
	if err != nil {
		// Directory might not exist yet; treat as "no datasets".
		return datasets, nil
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		info, err := entry.Info()
		if err != nil {
			continue
		}
		// Report the directory's modification time. The previous code
		// stamped every dataset with time.Now(), which made LastAccess
		// always equal to the scan time and therefore meaningless.
		lastAccess := info.ModTime()
		if lastAccess.IsZero() {
			lastAccess = time.Now() // defensive fallback for odd filesystems
		}
		datasets = append(datasets, model.DatasetInfo{
			Name:       entry.Name(),
			SizeBytes:  info.Size(), // NOTE(review): directory-entry size, not recursive dataset size — confirm intent
			Location:   filepath.Join(dataDir, entry.Name()),
			LastAccess: lastAccess,
		})
	}
	return datasets, nil
}
// ListExperiments retrieves experiment list

View file

@ -0,0 +1,275 @@
// Package services provides TUI service clients
package services
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// WebSocketClient manages real-time updates pushed from the server over a
// single WebSocket connection, fanning them out to typed channels.
type WebSocketClient struct {
	conn *websocket.Conn
	serverURL string // HTTP(S) base URL; converted to ws(s)://host/ws on connect
	apiKey string // sent as the X-API-Key header during the handshake
	logger *logging.Logger
	// Channels for different update types (buffered; see NewWebSocketClient)
	jobUpdates chan model.JobUpdateMsg
	gpuUpdates chan model.GPUUpdateMsg
	statusUpdates chan model.StatusMsg
	// Control
	ctx context.Context
	cancel context.CancelFunc
	// connected is read and written from multiple goroutines (Connect,
	// messageHandler, heartbeat, IsConnected) without synchronization —
	// NOTE(review): data race; consider sync/atomic.Bool.
	connected bool
}
// JobUpdateMsg represents a real-time job status update as decoded from
// the wire. NOTE(review): duplicates model.JobUpdateMsg field-for-field
// (handleBinaryMessage converts between them); consider decoding into
// the model type directly.
type JobUpdateMsg struct {
	JobName string `json:"job_name"`
	Status string `json:"status"`
	TaskID string `json:"task_id"`
	Progress int `json:"progress"` // presumably 0-100 percent — verify against server
}
// GPUUpdateMsg represents a real-time GPU status update as decoded from
// the wire. NOTE(review): duplicates model.GPUUpdateMsg field-for-field
// (handleBinaryMessage converts between them); consider decoding into
// the model type directly.
type GPUUpdateMsg struct {
	DeviceID int `json:"device_id"`
	Utilization int `json:"utilization"` // presumably percent — verify against server
	MemoryUsed int64 `json:"memory_used"` // presumably MB — verify against server
	MemoryTotal int64 `json:"memory_total"`
	Temperature int `json:"temperature"` // presumably degrees Celsius — verify
}
// NewWebSocketClient builds a client for serverURL that authenticates
// with apiKey and logs through logger. Call Connect to establish the
// connection and Disconnect to tear it down.
func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient {
	ctx, cancel := context.WithCancel(context.Background())
	// Buffer updates so short bursts don't immediately stall the reader.
	const updateBuffer = 100
	client := &WebSocketClient{
		serverURL:     serverURL,
		apiKey:        apiKey,
		logger:        logger,
		ctx:           ctx,
		cancel:        cancel,
		jobUpdates:    make(chan model.JobUpdateMsg, updateBuffer),
		gpuUpdates:    make(chan model.GPUUpdateMsg, updateBuffer),
		statusUpdates: make(chan model.StatusMsg, updateBuffer),
	}
	return client
}
// Connect establishes the WebSocket connection, registers the API key,
// and starts the reader and heartbeat goroutines.
//
// NOTE(review): every successful call spawns a fresh messageHandler and
// heartbeat goroutine, and messageHandler's reconnect path calls Connect
// again — so repeated reconnects accumulate goroutines. Consider a
// dedicated reconnect loop that reuses the existing goroutines.
func (c *WebSocketClient) Connect() error {
	// Parse server URL and construct WebSocket URL
	u, err := url.Parse(c.serverURL)
	if err != nil {
		return fmt.Errorf("invalid server URL: %w", err)
	}
	// Convert http/https to ws/wss (any other scheme falls back to ws)
	wsScheme := "ws"
	if u.Scheme == "https" {
		wsScheme = "wss"
	}
	wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)
	// Create dialer with timeout
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
		Subprotocols: []string{"fetchml-v1"},
	}
	// Add API key to headers
	headers := http.Header{}
	if c.apiKey != "" {
		headers.Set("X-API-Key", c.apiKey)
	}
	conn, resp, err := dialer.Dial(wsURL, headers)
	if err != nil {
		// On handshake failure gorilla may return a non-nil *http.Response;
		// include its status code for easier debugging.
		if resp != nil {
			return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err)
		}
		return fmt.Errorf("websocket dial failed: %w", err)
	}
	c.conn = conn
	c.connected = true // NOTE(review): unsynchronized write, raced with IsConnected/heartbeat
	c.logger.Info("websocket connected", "url", wsURL)
	// Start message handler
	go c.messageHandler()
	// Start heartbeat
	go c.heartbeat()
	return nil
}
// Disconnect cancels the client context, closes the underlying
// connection (if any) and marks the client as disconnected.
func (c *WebSocketClient) Disconnect() {
	c.cancel()
	if conn := c.conn; conn != nil {
		_ = conn.Close() // best-effort close; the connection is going away regardless
	}
	c.connected = false
}
// IsConnected reports whether the client currently believes it is
// connected. NOTE(review): reads the connected flag without
// synchronization while other goroutines write it — data race.
func (c *WebSocketClient) IsConnected() bool {
	return c.connected
}
// messageHandler reads frames from the WebSocket until the context is
// cancelled, dispatching binary and text frames to their handlers. On a
// read error it marks the client disconnected and attempts a reconnect
// after a short backoff.
func (c *WebSocketClient) messageHandler() {
	for {
		select {
		case <-c.ctx.Done():
			return
		default:
		}
		if c.conn == nil {
			time.Sleep(100 * time.Millisecond)
			continue
		}
		// Set a read deadline so a silent peer eventually surfaces as an error.
		if err := c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
			c.logger.Error("websocket set read deadline failed", "error", err)
		}
		// Read message
		messageType, data, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.logger.Error("websocket read error", "error", err)
			}
			c.connected = false
			// Attempt reconnect after a short backoff.
			time.Sleep(5 * time.Second)
			if err := c.Connect(); err != nil {
				c.logger.Error("websocket reconnect failed", "error", err)
				continue
			}
			// Connect spawned a fresh messageHandler for the new connection;
			// exit this one so reader goroutines don't multiply with every
			// reconnect (the previous code kept looping here).
			// NOTE(review): heartbeat goroutines still accumulate the same
			// way — Connect should not respawn them on reconnect.
			return
		}
		// Handle binary vs text messages
		if messageType == websocket.BinaryMessage {
			c.handleBinaryMessage(data)
		} else {
			c.handleTextMessage(data)
		}
	}
}
// handleBinaryMessage decodes a binary frame of the form
// [opcode:1][json payload...] and forwards the decoded update to the
// matching channel. Unknown opcodes are ignored.
//
// Sends are non-blocking: if a consumer stops draining its channel the
// update is dropped rather than stalling the websocket read loop (the
// previous blocking sends could wedge the reader until the read deadline
// expired).
func (c *WebSocketClient) handleBinaryMessage(data []byte) {
	if len(data) < 2 {
		return // need at least an opcode byte and one payload byte
	}
	// Binary protocol: [opcode:1][data...]
	opcode := data[0]
	payload := data[1:]
	switch opcode {
	case 0x01: // Job update
		var update JobUpdateMsg
		if err := json.Unmarshal(payload, &update); err != nil {
			c.logger.Error("failed to unmarshal job update", "error", err)
			return
		}
		select {
		case c.jobUpdates <- model.JobUpdateMsg(update):
		default: // consumer is behind; drop instead of blocking the reader
		}
	case 0x02: // GPU update
		var update GPUUpdateMsg
		if err := json.Unmarshal(payload, &update); err != nil {
			c.logger.Error("failed to unmarshal GPU update", "error", err)
			return
		}
		select {
		case c.gpuUpdates <- model.GPUUpdateMsg(update):
		default: // stale GPU readings are safe to drop
		}
	case 0x03: // Status message
		var status model.StatusMsg
		if err := json.Unmarshal(payload, &status); err != nil {
			c.logger.Error("failed to unmarshal status", "error", err)
			return
		}
		select {
		case c.statusUpdates <- status:
		default: // drop if nobody is listening
		}
	}
}
// handleTextMessage handles text WebSocket messages (JSON envelopes
// distinguished by a "type" field).
//
// NOTE(review): every message type is currently a no-op — only the binary
// protocol (handleBinaryMessage) actually delivers updates. Implement or
// remove before relying on text-mode messages.
func (c *WebSocketClient) handleTextMessage(data []byte) {
	var msg map[string]interface{}
	if err := json.Unmarshal(data, &msg); err != nil {
		c.logger.Error("failed to unmarshal text message", "error", err)
		return
	}
	msgType, _ := msg["type"].(string)
	switch msgType {
	case "job_update":
		// Handle JSON job updates
	case "gpu_update":
		// Handle JSON GPU updates
	case "status":
		// Handle status messages
	}
}
// heartbeat sends a WebSocket ping every 30 seconds until the context is
// cancelled, flipping connected to false when a ping fails.
//
// NOTE(review): this writes to c.conn concurrently with Subscribe;
// gorilla/websocket connections support at most one concurrent writer,
// so these writes need a shared write mutex. Also note that a failed
// ping marks the client disconnected but neither closes the connection
// nor stops this loop.
func (c *WebSocketClient) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
			if c.conn != nil {
				if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
					c.logger.Error("websocket ping failed", "error", err)
					c.connected = false
				}
			}
		}
	}
}
// Subscribe asks the server to push updates for the named channels.
// It returns an error if the client is not connected, if the request
// cannot be encoded, or if the write fails.
//
// NOTE(review): writes to the connection concurrently with heartbeat's
// pings; gorilla/websocket allows only one concurrent writer, so a
// shared write mutex is needed.
func (c *WebSocketClient) Subscribe(channels ...string) error {
	if !c.connected {
		return fmt.Errorf("not connected")
	}
	subMsg := map[string]interface{}{
		"action":   "subscribe",
		"channels": channels,
	}
	data, err := json.Marshal(subMsg)
	if err != nil {
		// Previously discarded with `_`; surface encoding failures.
		return fmt.Errorf("encoding subscribe request: %w", err)
	}
	return c.conn.WriteMessage(websocket.TextMessage, data)
}
// GetJobUpdates returns the receive-only channel of job status updates.
// NOTE(review): the Get* prefix is unidiomatic Go naming, but renaming
// would break callers.
func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg {
	return c.jobUpdates
}
// GetGPUUpdates returns the receive-only channel of GPU status updates.
func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg {
	return c.gpuUpdates
}
// GetStatusUpdates returns the receive-only channel of status messages.
func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg {
	return c.statusUpdates
}

View file

@ -0,0 +1,147 @@
// Package view provides TUI rendering functionality
package view
import (
"strings"
"github.com/charmbracelet/lipgloss"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
// NarrativeView displays a job's research context (hypothesis, context,
// intent) and its expected/actual outcome.
type NarrativeView struct {
	Width int // available render width in cells
	Height int // available render height in cells (currently unused by View)
}
// View renders the narrative/outcome panel for a job. When the job has
// no narrative data at all, it renders a hint explaining how to add it.
func (v *NarrativeView) View(job model.Job) string {
	// Treat the job as empty only when *all* narrative fields are blank.
	// The previous check ignored ExpectedOutcome and ActualOutcome, so a
	// job carrying only outcome data incorrectly showed the placeholder.
	if job.Hypothesis == "" && job.Context == "" && job.Intent == "" &&
		job.ExpectedOutcome == "" && job.ActualOutcome == "" {
		return v.renderEmpty()
	}
	var sections []string
	// Title
	titleStyle := lipgloss.NewStyle().
		Bold(true).
		Foreground(lipgloss.Color("#ff9e64")).
		MarginBottom(1)
	sections = append(sections, titleStyle.Render("📊 Research Context"))
	// Hypothesis
	if job.Hypothesis != "" {
		sections = append(sections, v.renderSection("🧪 Hypothesis", job.Hypothesis))
	}
	// Context
	if job.Context != "" {
		sections = append(sections, v.renderSection("📚 Context", job.Context))
	}
	// Intent
	if job.Intent != "" {
		sections = append(sections, v.renderSection("🎯 Intent", job.Intent))
	}
	// Expected Outcome
	if job.ExpectedOutcome != "" {
		sections = append(sections, v.renderSection("📈 Expected Outcome", job.ExpectedOutcome))
	}
	// Actual Outcome (if available), prefixed with a validation icon.
	if job.ActualOutcome != "" {
		statusIcon := "❓"
		switch job.OutcomeStatus {
		case "validated":
			statusIcon = "✅"
		case "invalidated":
			statusIcon = "❌"
		case "inconclusive":
			statusIcon = "⚖️"
		case "partial":
			statusIcon = "📝"
		}
		sections = append(sections, v.renderSection(statusIcon+" Actual Outcome", job.ActualOutcome))
	}
	// Combine all sections
	content := lipgloss.JoinVertical(lipgloss.Left, sections...)
	// Apply border and padding
	style := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("#7aa2f7")).
		Padding(1, 2).
		Width(v.Width - 4)
	return style.Render(content)
}
// renderSection renders one titled narrative section: a highlighted
// heading followed by indented, word-wrapped body text and a blank line.
func (v *NarrativeView) renderSection(title, content string) string {
	heading := lipgloss.NewStyle().
		Bold(true).
		Foreground(lipgloss.Color("#7aa2f7")).
		Render(title)
	body := lipgloss.NewStyle().
		Foreground(lipgloss.Color("#d8dee9")).
		MarginLeft(2).
		Render(wrapText(content, v.Width-12))
	return lipgloss.JoinVertical(lipgloss.Left, heading, body, "")
}
// renderEmpty renders the placeholder shown when a job carries no
// narrative metadata.
func (v *NarrativeView) renderEmpty() string {
	msg := "No narrative data available for this job.\n\n" +
		"Use --hypothesis, --context, --intent flags when queuing jobs."
	return lipgloss.NewStyle().
		Foreground(lipgloss.Color("#88c0d0")).
		Italic(true).
		Padding(2).
		Render(msg)
}
// wrapText word-wraps text to at most maxWidth characters per line,
// preserving existing newlines. Lines already within the limit pass
// through untouched (including their internal whitespace); longer lines
// are re-flowed word by word. A non-positive maxWidth disables wrapping.
// Note: widths are measured in bytes, so multi-byte runes count as more
// than one column.
func wrapText(text string, maxWidth int) string {
	if maxWidth <= 0 {
		return text
	}
	var out []string
	for _, raw := range strings.Split(text, "\n") {
		if len(raw) <= maxWidth {
			out = append(out, raw)
			continue
		}
		// Re-flow this over-long line one word at a time.
		cur := ""
		for _, word := range strings.Fields(raw) {
			switch {
			case len(cur)+len(word)+1 > maxWidth:
				if cur != "" {
					out = append(out, cur)
				}
				cur = word
			case cur == "":
				cur = word
			default:
				cur += " " + word
			}
		}
		if cur != "" {
			out = append(out, cur)
		}
	}
	return strings.Join(out, "\n")
}

View file

@ -2,6 +2,7 @@
package view
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
@ -180,6 +181,27 @@ func getRightPanel(m model.State, width int) string {
style = activeBorderStyle
viewTitle = "📦 Datasets"
content = m.DatasetView.View()
case model.ViewModeNarrative:
style = activeBorderStyle
viewTitle = "📊 Research Context"
narrativeView := &NarrativeView{Width: width, Height: m.Height}
content = narrativeView.View(m.SelectedJob)
case model.ViewModeTeam:
style = activeBorderStyle
viewTitle = "👥 Team Jobs"
content = "Team collaboration view - shows jobs from all team members\n\n(Requires API: GET /api/jobs?all_users=true)"
case model.ViewModeExperimentHistory:
style = activeBorderStyle
viewTitle = "📜 Experiment History"
content = m.ExperimentHistoryView.View()
case model.ViewModeConfig:
style = activeBorderStyle
viewTitle = "⚙️ Config"
content = m.ConfigView.View()
case model.ViewModeLogs:
style = activeBorderStyle
viewTitle = "📜 Logs"
content = m.LogsView.View()
default:
viewTitle = "📊 System Overview"
content = getOverviewPanel(m)
@ -218,7 +240,14 @@ func getStatusBar(m model.State) string {
if m.ShowHelp {
statusText = "Press 'h' to hide help"
}
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText)
// Add refresh rate indicator
refreshInfo := ""
if m.RefreshRate > 0 {
refreshInfo = fmt.Sprintf(" | %.0fms", m.RefreshRate)
}
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText + refreshInfo)
}
func helpText(m model.State) string {
@ -242,25 +271,28 @@ func helpText(m model.State) string {
Navigation
j/k, / : Move selection / : Filter jobs
1 : Job list view 2 : Datasets view
3 : Experiments view v : Queue view
3 : Experiments view q : Queue view
g : GPU view o : Container view
s : Settings view
s : Settings view n : Narrative view
m : Team view e : Experiment history
c : Config view l : Logs (job stream)
E : Export job @ : Filter by team
Actions
t : Queue job a : Queue w/ args
c : Cancel task d : Delete pending
x : Cancel task d : Delete pending
f : Mark as failed r : Refresh all
G : Refresh GPU only
General
h or ? : Toggle this help q/Ctrl+C : Quit
h or ? : Toggle this help Ctrl+C : Quit
`
}
func getQuickHelp(m model.State) string {
if m.ActiveView == model.ViewModeSettings {
return " ↑/↓:move enter:select esc:exit settings q:quit"
return " ↑/↓:move enter:select esc:exit settings Ctrl+C:quit"
}
return " h:help 1:jobs 2:datasets 3:experiments v:queue g:gpu o:containers " +
"s:settings t:queue r:refresh q:quit"
return " h:help 1:jobs 2:datasets 3:experiments q:queue g:gpu o:containers " +
"n:narrative m:team e:history c:config l:logs E:export @:filter s:settings t:trigger r:refresh Ctrl+C:quit"
}

View file

@ -5,6 +5,7 @@ import (
"log"
"os"
"os/signal"
"path/filepath"
"syscall"
tea "github.com/charmbracelet/bubbletea"
@ -41,6 +42,16 @@ func (m AppModel) View() string {
}
func main() {
// Redirect logs to file to prevent TUI disruption
homeDir, _ := os.UserHomeDir()
logDir := filepath.Join(homeDir, ".ml", "logs")
os.MkdirAll(logDir, 0755)
logFile, logErr := os.OpenFile(filepath.Join(logDir, "tui.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if logErr == nil {
log.SetOutput(logFile)
defer logFile.Close()
}
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
if err := auth.ValidateFlags(authFlags); err != nil {
@ -85,10 +96,10 @@ func main() {
log.Printf(" 4. Run TUI: ./bin/tui")
log.Printf("")
log.Printf("Example ~/.ml/config.toml:")
log.Printf(" worker_host = \"localhost\"")
log.Printf(" worker_user = \"your_username\"")
log.Printf(" worker_base = \"~/ml_jobs\"")
log.Printf(" worker_port = 22")
log.Printf(" mode = \"dev\"")
log.Printf(" # Paths auto-resolve based on mode:")
log.Printf(" # dev mode: ./data/dev/experiments")
log.Printf(" # prod mode: ./data/prod/experiments")
log.Printf(" api_key = \"your_api_key_here\"")
log.Printf("")
log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs")
@ -101,6 +112,11 @@ func main() {
}
log.Printf("Loaded TOML configuration from %s", cliConfPath)
// Force local mode - TUI runs on server with direct filesystem access
cfg.Host = ""
// Clear BasePath to force mode-based path resolution
cfg.BasePath = ""
// Validate authentication configuration
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
log.Fatalf("Invalid authentication configuration: %v", err)
@ -130,7 +146,7 @@ func main() {
srv, err := services.NewMLServer(cfg)
if err != nil {
log.Fatalf("Failed to connect to server: %v", err)
log.Fatalf("Failed to initialize local server: %v", err)
}
defer func() {
if err := srv.Close(); err != nil {
@ -138,19 +154,24 @@ func main() {
}
}()
// TaskQueue is optional for local mode
tq, err := services.NewTaskQueue(cfg)
if err != nil {
log.Printf("Failed to connect to Redis: %v", err)
return
log.Printf("Warning: Failed to connect to Redis: %v", err)
log.Printf("Continuing without task queue functionality")
tq = nil
}
if tq != nil {
defer func() {
if err := tq.Close(); err != nil {
log.Printf("task queue close error: %v", err)
}
}()
}
// Initialize logger with error level and no debug output
logger := logging.NewLogger(-4, false) // -4 = slog.LevelError
// Initialize logger with file output only (prevents TUI disruption)
logFilePath := filepath.Join(logDir, "tui.log")
logger := logging.NewFileLogger(-4, false, logFilePath) // -4 = slog.LevelError
// Initialize State and Controller
var effectiveAPIKey string
@ -177,14 +198,17 @@ func main() {
go func() {
<-sigChan
logger.Info("Received shutdown signal, closing TUI...")
p.Quit()
}()
if _, err := p.Run(); err != nil {
_ = p.ReleaseTerminal()
log.Printf("Error running TUI: %v", err)
logger.Error("Error running TUI", "error", err)
return
}
// Ensure terminal is released and resources are closed via defer statements
_ = p.ReleaseTerminal()
logger.Info("TUI shutdown complete")
}

View file

@ -19,7 +19,7 @@ security:
api_key_rotation_days: 90
audit_logging:
enabled: true
log_path: "/tmp/fetchml-audit.log"
log_path: "data/dev/logs/fetchml-audit.log"
rate_limit:
enabled: false
requests_per_minute: 60
@ -46,8 +46,8 @@ database:
logging:
level: "info"
file: ""
audit_log: ""
file: "data/dev/logs/fetchml.log"
audit_log: "data/dev/logs/fetchml-audit.log"
resources:
max_workers: 1

View file

@ -1,6 +1,6 @@
base_path: "./data/experiments"
base_path: "./data/dev/experiments"
data_dir: "./data/active"
data_dir: "./data/dev/active"
auth:
enabled: false
@ -19,7 +19,7 @@ security:
api_key_rotation_days: 90
audit_logging:
enabled: true
log_path: "./data/fetchml-audit.log"
log_path: "./data/dev/logs/fetchml-audit.log"
rate_limit:
enabled: false
requests_per_minute: 60
@ -42,12 +42,12 @@ redis:
database:
type: "sqlite"
connection: "./data/fetchml.sqlite"
connection: "./data/dev/fetchml.sqlite"
logging:
level: "info"
file: ""
audit_log: ""
file: "./data/dev/logs/fetchml.log"
audit_log: "./data/dev/logs/fetchml-audit.log"
resources:
max_workers: 1

View file

@ -1,6 +1,6 @@
base_path: "/app/data/experiments"
base_path: "/app/data/prod/experiments"
data_dir: "/data/active"
data_dir: "/app/data/prod/active"
auth:
enabled: true
@ -45,12 +45,12 @@ redis:
database:
type: "sqlite"
connection: "/app/data/experiments/fetch_ml.sqlite"
connection: "/app/data/prod/fetch_ml.sqlite"
logging:
level: "info"
file: "/logs/fetch_ml.log"
audit_log: ""
file: "/app/data/prod/logs/fetch_ml.log"
audit_log: "/app/data/prod/logs/audit.log"
resources:
max_workers: 2

View file

@ -36,6 +36,22 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
// Register HTTP API handlers
s.handlers.RegisterHandlers(mux)
// Register new REST API endpoints for TUI
jobsHandler := jobs.NewHandler(
s.expManager,
s.logger,
s.taskQueue,
s.db,
s.config.BuildAuthConfig(),
nil,
)
// Experiment history endpoint: GET /api/experiments/:id/history
mux.HandleFunc("GET /api/experiments/{id}/history", jobsHandler.GetExperimentHistoryHTTP)
// Team jobs endpoint: GET /api/jobs?all_users=true
mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP)
}
// registerHealthRoutes sets up health check endpoints

View file

@ -12,6 +12,8 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/audit"
@ -110,6 +112,22 @@ const (
PermJupyterRead = "jupyter:read"
)
// ClientType represents the type of WebSocket client connected to the
// server (command-line client vs. terminal UI). Broadcast helpers filter
// recipients by this type.
type ClientType int
const (
// ClientTypeCLI marks a connection from the command-line client.
ClientTypeCLI ClientType = iota
// ClientTypeTUI marks a connection from the terminal UI; push-style
// broadcasts (job/GPU updates) are sent only to this type.
ClientTypeTUI
)
// Client represents a connected WebSocket client tracked by the Handler
// for push updates.
type Client struct {
conn *websocket.Conn // underlying connection; broadcast frames are written here
Type ClientType // kind of client, used to filter broadcast recipients
User string // user identity associated with the connection
RemoteAddr string // remote network address, kept for log messages
}
// Handler provides WebSocket handling
type Handler struct {
authConfig *auth.Config
@ -125,6 +143,10 @@ type Handler struct {
jobsHandler *jobs.Handler
jupyterHandler *jupyterj.Handler
datasetsHandler *datasets.Handler
// Client management for push updates
clients map[*Client]bool
clientsMu sync.RWMutex
}
// NewHandler creates a new WebSocket handler
@ -158,6 +180,7 @@ func NewHandler(
jobsHandler: jobsHandler,
jupyterHandler: jupyterHandler,
datasetsHandler: datasetsHandler,
clients: make(map[*Client]bool),
}
}
@ -217,6 +240,25 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleConnection(conn *websocket.Conn) {
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
// Register client
client := &Client{
conn: conn,
Type: ClientTypeTUI, // Assume TUI for now, could detect from handshake
User: "tui-user",
RemoteAddr: conn.RemoteAddr().String(),
}
h.clientsMu.Lock()
h.clients[client] = true
h.clientsMu.Unlock()
defer func() {
h.clientsMu.Lock()
delete(h.clients, client)
h.clientsMu.Unlock()
conn.Close()
}()
for {
messageType, payload, err := conn.ReadMessage()
if err != nil {
@ -765,3 +807,66 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
"message": "Outcome updated",
})
}
// BroadcastJobUpdate sends a job status update to all connected TUI clients.
//
// jobName identifies the job, status is its new state, and progress is the
// reported completion value. The update is delivered as a JSON text frame
// tagged "job_update" with a Unix timestamp. Per-client write failures are
// logged and do not stop delivery to the remaining clients.
//
// NOTE(review): writes occur under only a read lock; gorilla/websocket
// connections do not support concurrent writers, so overlapping broadcasts
// to the same client would need a per-client write mutex — confirm call sites.
func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) {
	msg := map[string]any{
		"type":      "job_update",
		"job_name":  jobName,
		"status":    status,
		"progress":  progress,
		"time":      time.Now().Unix(),
	}
	// Marshal before taking the lock; the payload does not depend on client
	// state, so there is no reason to hold the lock while encoding.
	payload, err := json.Marshal(msg)
	if err != nil {
		// A map of primitives should never fail to marshal, but do not
		// silently drop the update if it somehow does.
		h.logger.Warn("failed to marshal job update", "error", err)
		return
	}

	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()
	for client := range h.clients {
		if client.Type != ClientTypeTUI {
			continue
		}
		if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
			h.logger.Warn("failed to broadcast to client", "error", err, "client", client.RemoteAddr)
		}
	}
}
// BroadcastGPUUpdate sends a GPU status update to all connected TUI clients.
//
// deviceID identifies the GPU, utilization is its reported load, and
// memoryUsed/memoryTotal are byte counts as supplied by the caller. The
// update is delivered as a JSON text frame tagged "gpu_update" with a Unix
// timestamp; per-client write failures are logged and skipped.
func (h *Handler) BroadcastGPUUpdate(deviceID, utilization int, memoryUsed, memoryTotal int64) {
	msg := map[string]any{
		"type":         "gpu_update",
		"device_id":    deviceID,
		"utilization":  utilization,
		"memory_used":  memoryUsed,
		"memory_total": memoryTotal,
		"time":         time.Now().Unix(),
	}
	// Encode outside the lock; report (rather than silently ignore) the
	// marshal error the original discarded with a blank identifier.
	payload, err := json.Marshal(msg)
	if err != nil {
		h.logger.Warn("failed to marshal GPU update", "error", err)
		return
	}

	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()
	for client := range h.clients {
		if client.Type != ClientTypeTUI {
			continue
		}
		if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
			h.logger.Warn("failed to broadcast GPU update", "error", err, "client", client.RemoteAddr)
		}
	}
}
// GetConnectedClientCount reports how many of the currently registered
// WebSocket clients identify as TUI clients. It takes the read lock, so it
// is safe to call concurrently with connection handling.
func (h *Handler) GetConnectedClientCount() int {
	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()

	n := 0
	for c := range h.clients {
		if c.Type == ClientTypeTUI {
			n++
		}
	}
	return n
}

View file

@ -229,6 +229,37 @@ func (s *SmartDefaults) ExpandPath(path string) string {
return path
}
// ModeBasedPath returns the data path for the given runtime mode with
// subpath appended beneath it. Recognized modes are "dev", "prod", "ci",
// and "prod-smoke"; any other value falls back to the "dev" tree.
func ModeBasedPath(mode string, subpath string) string {
	// Unknown or empty modes resolve to the dev tree.
	dir := "dev"
	switch mode {
	case "dev", "prod", "ci", "prod-smoke":
		dir = mode
	}
	return filepath.Join("./data", dir, subpath)
}
// ModeBasedBasePath returns the experiments base path for a given mode.
// It delegates to ModeBasedPath with the "experiments" subdirectory.
func ModeBasedBasePath(mode string) string {
return ModeBasedPath(mode, "experiments")
}
// ModeBasedDataDir returns the data directory for a given mode.
// It delegates to ModeBasedPath with the "active" subdirectory.
func ModeBasedDataDir(mode string) string {
return ModeBasedPath(mode, "active")
}
// ModeBasedLogDir returns the logs directory for a given mode.
// It delegates to ModeBasedPath with the "logs" subdirectory.
func ModeBasedLogDir(mode string) string {
return ModeBasedPath(mode, "logs")
}
// GetEnvironmentDescription returns a human-readable description
func (s *SmartDefaults) GetEnvironmentDescription() string {
switch s.Profile {