feat: update CLI, TUI, and security documentation

- Add safety checks to Zig build
- Add TUI with job management and narrative views
- Add WebSocket support and export services
- Add smart configuration defaults
- Update API routes with security headers
- Update SECURITY.md with comprehensive policy
- Add Makefile security scanning targets
This commit is contained in:
Jeremie Fraeys 2026-02-19 15:35:05 -05:00
parent 02811c0ffe
commit 6028779239
No known key found for this signature in database
34 changed files with 2956 additions and 175 deletions

View file

@ -1,4 +1,4 @@
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs security-scan gosec govulncheck check-unsafe security-audit test-security
OK =
DOCS_PORT ?= 1313
DOCS_BIND ?= 127.0.0.1
@ -498,3 +498,60 @@ prod-status:
prod-logs:
@./deployments/deploy.sh prod logs
# =============================================================================
# SECURITY TARGETS
# =============================================================================
.PHONY: security-scan gosec govulncheck check-unsafe security-audit

# Run all security scans (static analysis, vulnerability DB, unsafe usage).
security-scan: gosec govulncheck check-unsafe
	@echo "${OK} Security scan complete"

# Run the gosec security linter. Installs gosec on first use, then always
# produces JSON + SARIF reports under reports/ and prints findings to the
# console. Findings do not fail the target (reports are advisory).
gosec:
	@mkdir -p reports
	@echo "Running gosec security scan..."
	@if ! command -v gosec >/dev/null 2>&1; then \
		echo "Installing gosec..."; \
		go install github.com/securego/gosec/v2/cmd/gosec@latest; \
	fi
	@gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true
	@gosec -fmt=sarif -out=reports/gosec-results.sarif ./... 2>/dev/null || true
	@gosec ./... 2>/dev/null || echo "Note: gosec found issues (see reports/gosec-results.json)"
	@echo "${OK} gosec scan complete (see reports/gosec-results.*)"

# Run govulncheck against the Go vulnerability database, installing it on
# first use. A non-zero govulncheck exit fails the target.
govulncheck:
	@echo "Running govulncheck for known vulnerabilities..."
	@if command -v govulncheck >/dev/null 2>&1; then \
		govulncheck ./...; \
	else \
		echo "Installing govulncheck..."; \
		go install golang.org/x/vuln/cmd/govulncheck@latest; \
		govulncheck ./...; \
	fi
	@echo "${OK} govulncheck complete"

# Fail the build if any Go source under ./internal, ./cmd or ./pkg references
# the unsafe package.
# NOTE(review): this grep also matches "unsafe." inside comments and string
# literals — hits mean "review required", not proven unsafe usage.
check-unsafe:
	@echo "Checking for unsafe package usage..."
	@if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then \
		echo "WARNING: Found unsafe package usage (review required)"; \
		exit 1; \
	else \
		echo "${OK} No unsafe package usage found"; \
	fi

# Full security audit (all scans + security-specific tests).
security-audit: security-scan test-security
	@echo "${OK} Full security audit complete"

# Run security-specific tests; deliberately tolerates a missing
# ./tests/security tree (best-effort until Phase 5 lands).
test-security:
	@echo "Running security tests..."
	@go test -v ./tests/security/... 2>/dev/null || echo "Note: No security tests yet (will be added in Phase 5)"
	@echo "${OK} Security tests complete"

View file

@ -1,6 +1,153 @@
# Security Policy
## Reporting a Vulnerability
Please report security vulnerabilities to security@fetchml.io.
Do NOT open public issues for security bugs.
Response timeline:
- Acknowledgment: within 48 hours
- Initial assessment: within 5 days
- Fix released: within 30 days (critical), 90 days (high)
## Security Features
FetchML implements defense-in-depth security for ML research systems:
### Authentication & Authorization
- **Argon2id API Key Hashing**: Memory-hard hashing resists GPU cracking
- **RBAC with Role Inheritance**: Granular permissions (admin, data_scientist, data_engineer, viewer, operator)
- **Constant-time Comparison**: Prevents timing attacks on key validation
### Cryptographic Practices
- **Ed25519 Manifest Signing**: Tamper detection for run manifests
- **SHA-256 with Salt**: Legacy key support with migration path
- **Secure Key Generation**: 256-bit entropy for all API keys
### Container Security
- **Rootless Podman**: No privileged containers
- **Capability Dropping**: `--cap-drop ALL` by default
- **No New Privileges**: `no-new-privileges` security opt
- **Read-only Root Filesystem**: Immutable base image
### Input Validation
- **Path Traversal Prevention**: Canonical path validation
- **Command Injection Protection**: Shell metacharacter filtering
- **Length Limits**: Prevents DoS via oversized inputs
### Audit & Monitoring
- **Structured Audit Logging**: JSON-formatted security events
- **Hash-chained Logs**: Tamper-evident audit trail
- **Anomaly Detection**: Brute force, privilege escalation alerts
- **Security Metrics**: Prometheus integration
### Supply Chain
- **Dependency Scanning**: gosec + govulncheck in CI
- **No unsafe Package**: Prohibited in production code
- **Manifest Signing**: Ed25519 signatures for integrity
## Supported Versions
| Version | Supported |
| ------- | ------------------ |
| 0.2.x | :white_check_mark: |
| 0.1.x | :x: |
## Security Checklist (Pre-Release)
### Code Review
- [ ] No hardcoded secrets
- [ ] No `unsafe` usage without justification
- [ ] All user inputs validated
- [ ] All file paths canonicalized
- [ ] No secrets in error messages
### Dependency Audit
- [ ] `go mod verify` passes
- [ ] `govulncheck` shows no vulnerabilities
- [ ] All dependencies pinned
- [ ] No unmaintained dependencies
### Container Security
- [ ] No privileged containers
- [ ] Rootless execution
- [ ] Seccomp/AppArmor applied
- [ ] Network isolation
### Cryptography
- [ ] Argon2id for key hashing
- [ ] Ed25519 for signing
- [ ] TLS 1.3 only
- [ ] No weak ciphers
### Testing
- [ ] Security tests pass
- [ ] Fuzz tests for parsers
- [ ] Authentication bypass tested
- [ ] Container escape tested
## Security Commands
```bash
# Run security scan
make security-scan
# Check for vulnerabilities
govulncheck ./...
# Static analysis
gosec ./...
# Check for unsafe usage
grep -r "unsafe\." --include="*.go" ./internal ./cmd
# Build with sanitizers
cd native && cmake -DENABLE_ASAN=ON .. && make
```
## Threat Model
### Attack Surfaces
1. **External API**: Researchers submitting malicious jobs
2. **Container Runtime**: Escape to host system
3. **Data Exfiltration**: Stealing datasets/models
4. **Privilege Escalation**: Researcher → admin
5. **Supply Chain**: Compromised dependencies
6. **Secrets Leakage**: API keys in logs/errors
### Mitigations
| Threat | Mitigation |
|--------|------------|
| Malicious Jobs | Input validation, container sandboxing, resource limits |
| Container Escape | Rootless, no-new-privileges, seccomp, read-only root |
| Data Exfiltration | Network policies, audit logging, rate limiting |
| Privilege Escalation | RBAC, least privilege, anomaly detection |
| Supply Chain | Dependency scanning, manifest signing, pinned versions |
| Secrets Leakage | Log sanitization, secrets manager, memory clearing |
## Responsible Disclosure
We follow responsible disclosure practices:
1. **Report privately**: Email security@fetchml.io with details
2. **Provide details**: Steps to reproduce, impact assessment
3. **Allow time**: We need 30-90 days to fix before public disclosure
4. **Acknowledgment**: We credit researchers who report valid issues
## Security Team
- security@fetchml.io - Security issues and questions
- security-response@fetchml.io - Active incident response
---
*Last updated: 2026-02-19*
---
# Security Guide for Fetch ML Homelab
This guide covers security best practices for deploying Fetch ML in a homelab environment.
## Quick Setup

View file

@ -76,6 +76,12 @@ pub fn build(b: *std.Build) void {
exe.root_module.addOptions("build_options", options);
// Link native dataset_hash library
exe.linkLibC();
exe.addLibraryPath(b.path("../native/build"));
exe.linkSystemLibrary("dataset_hash");
exe.addIncludePath(b.path("../native/dataset_hash"));
// Install the executable to zig-out/bin
b.installArtifact(exe);
@ -94,6 +100,13 @@ pub fn build(b: *std.Build) void {
// Standard Zig test discovery - find all test files automatically
const test_step = b.step("test", "Run unit tests");
// Safety check for release builds
const safety_check_step = b.step("safety-check", "Verify ReleaseSafe mode is used for production");
if (optimize != .ReleaseSafe and optimize != .Debug) {
const warn_no_safe = b.addSystemCommand(&.{ "echo", "WARNING: Building without ReleaseSafe mode. Production builds should use -Doptimize=ReleaseSafe" });
safety_check_step.dependOn(&warn_no_safe.step);
}
// Test main executable
const main_tests = b.addTest(.{
.root_module = b.createModule(.{

View file

@ -429,11 +429,23 @@ fn queueSingleJob(
const commit_hex = try crypto.encodeHexLower(allocator, commit_id);
defer allocator.free(commit_hex);
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
defer allocator.free(api_key_hash);
// Check for existing job with same commit (incremental queue)
if (!options.force) {
const existing = try checkExistingJob(allocator, job_name, commit_id, api_key_hash, config);
if (existing) |ex| {
defer allocator.free(ex);
// Server already has this job - handle duplicate response
try handleDuplicateResponse(allocator, ex, job_name, commit_hex, options);
return;
}
}
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
// Connect to WebSocket and send queue message
const ws_url = try config.getWebSocketUrl(allocator);
defer allocator.free(ws_url);
@ -1145,3 +1157,46 @@ fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions
return try buf.toOwnedSlice(allocator);
}
/// Check if a job with the same commit_id already exists on the server
/// (incremental-queue support).
///
/// Opens a short-lived WebSocket connection, sends a query_job request and
/// inspects the JSON reply. Returns a caller-owned copy of the full JSON
/// response when the server reports the job exists, or null otherwise.
/// Any malformed reply is treated as "no duplicate found" so queueing
/// proceeds rather than aborting.
fn checkExistingJob(
    allocator: std.mem.Allocator,
    job_name: []const u8,
    commit_id: []const u8,
    api_key_hash: []const u8,
    config: Config,
) !?[]const u8 {
    // Connect to server and query for existing job.
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, config.api_key);
    defer client.close();

    // Send query for existing job and wait for the single reply.
    try client.sendQueryJobByCommit(job_name, commit_id, api_key_hash);
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);

    // Parse response; a parse failure is "no duplicate found".
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch |err| {
        std.log.debug("Failed to parse check response: {}", .{err});
        return null;
    };
    defer parsed.deinit();

    // Guard the union tags so a malformed reply (non-object root, or e.g.
    // "exists": "yes") is treated as no-duplicate instead of hitting a
    // safety panic on the wrong-tag access.
    if (parsed.value != .object) return null;
    const root = parsed.value.object;

    // Check if job exists.
    if (root.get("exists")) |exists| {
        if (exists != .bool or !exists.bool) return null;
        // Job exists - copy the full response for the caller (the original
        // `message` buffer is freed by the defer above).
        return try allocator.dupe(u8, message);
    }
    return null;
}

View file

@ -6,6 +6,7 @@ const rsync = @import("../utils/rsync_embedded.zig");
const ws = @import("../net/ws/client.zig");
const logging = @import("../utils/logging.zig");
const json = @import("../utils/json.zig");
const native_hash = @import("../utils/native_hash.zig");
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
if (args.len == 0) {
@ -26,6 +27,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
var should_queue = false;
var priority: u8 = 5;
var json_mode: bool = false;
var dev_mode: bool = false;
var use_timestamp_check = false;
var dry_run = false;
// Parse flags
var i: usize = 1;
@ -40,6 +44,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
i += 1;
} else if (std.mem.eql(u8, args[i], "--dev")) {
dev_mode = true;
} else if (std.mem.eql(u8, args[i], "--check-timestamp")) {
use_timestamp_check = true;
} else if (std.mem.eql(u8, args[i], "--dry-run")) {
dry_run = true;
}
}
@ -49,16 +59,59 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
mut_config.deinit(allocator);
}
// Calculate commit ID (SHA256 of directory tree)
const commit_id = try crypto.hashDirectory(allocator, path);
defer allocator.free(commit_id);
// Detect if path is a subdirectory by finding git root
const git_root = try findGitRoot(allocator, path);
defer if (git_root) |gr| allocator.free(gr);
// Construct remote destination path
const remote_path = try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
);
const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path);
const relative_path = if (is_subdir) blk: {
// Get relative path from git root to the specified path
break :blk try std.fs.path.relative(allocator, git_root.?, path);
} else null;
defer if (relative_path) |rp| allocator.free(rp);
// Determine commit_id and remote path based on mode
const commit_id: []const u8 = if (dev_mode) blk: {
// Dev mode: skip expensive hashing, use fixed "dev" commit
break :blk "dev";
} else blk: {
// Production mode: calculate SHA256 of directory tree (always from git root)
const hash_base = git_root orelse path;
break :blk try crypto.hashDirectory(allocator, hash_base);
};
defer if (!dev_mode) allocator.free(commit_id);
// In dev mode, sync to {worker_base}/dev/files/ instead of hashed path
// For subdirectories, append the relative path to the remote destination
const remote_path = if (dev_mode) blk: {
if (is_subdir) {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/dev/files/{s}/",
.{ config.api_key, config.worker_host, config.worker_base, relative_path.? },
);
} else {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/dev/files/",
.{ config.api_key, config.worker_host, config.worker_base },
);
}
} else blk: {
if (is_subdir) {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/{s}/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? },
);
} else {
break :blk try std.fmt.allocPrint(
allocator,
"{s}@{s}:{s}/{s}/files/",
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
);
}
};
defer allocator.free(remote_path);
// Sync using embedded rsync (no external binary needed)
@ -102,6 +155,9 @@ fn printUsage() void {
logging.err(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
logging.err(" --monitor Wait and show basic sync progress\n", .{});
logging.err(" --json Output machine-readable JSON (sync result only)\n", .{});
logging.err(" --dev Dev mode: skip hashing, use fixed path (fast)\n", .{});
logging.err(" --check-timestamp Skip files unchanged since last sync\n", .{});
logging.err(" --dry-run Show what would be synced without transferring\n", .{});
logging.err(" --help, -h Show this help message\n", .{});
}
@ -175,3 +231,29 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm
std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{});
}
}
/// Find the git root directory by walking up from the given path.
///
/// Resolves `start_path` to an absolute path, then probes each ancestor for a
/// `.git` entry. Returns a caller-owned copy of the git root path, or null if
/// no ancestor contains `.git`.
///
/// Note: the walk cursor is `[]const u8` because `std.fs.path.dirname`
/// returns `?[]const u8` (the original `var current = path` was `[]u8`, which
/// cannot be re-assigned from dirname's const slice).
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const resolved = try std.fs.realpath(start_path, &buf);
    var current: []const u8 = resolved;
    while (true) {
        // Check if .git exists in current directory.
        const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" });
        // Zig defers run at block (iteration) exit, so this frees each probe.
        defer allocator.free(git_path);
        if (std.fs.accessAbsolute(git_path, .{})) {
            // Found .git directory (or file, for worktrees/submodules).
            return try allocator.dupe(u8, current);
        } else |_| {
            // .git not found here; step to the parent directory.
            const parent = std.fs.path.dirname(current) orelse return null;
            if (std.mem.eql(u8, parent, current)) {
                // Reached the filesystem root without finding .git.
                return null;
            }
            current = parent;
        }
    }
}

View file

@ -220,6 +220,36 @@ pub const Client = struct {
try builder.send(stream);
}
/// Query the server for an existing job with the given name and commit id.
/// Sends one binary query_job frame; the caller reads the reply separately
/// via receiveMessage().
///
/// Wire layout (all fields contiguous, no padding):
/// [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] [commit_id: 20 bytes]
///
/// NOTE(review): the fixed 16/20-byte widths must be enforced by
/// validateApiKeyHash/validateCommitId, otherwise the @memcpy calls below
/// write out of bounds — confirm those validators check exact lengths.
/// NOTE(review): 20 bytes is SHA-1-sized, while crypto.hashDirectory in this
/// codebase produces 32-byte SHA-256 digests — confirm which form callers pass.
pub fn sendQueryJobByCommit(self: *Client, job_name: []const u8, commit_id: []const u8, api_key_hash: []const u8) !void {
    const stream = try self.getStream();
    // Validate all inputs before allocating the frame buffer.
    try validateApiKeyHash(api_key_hash);
    try validateCommitId(commit_id);
    // presumably also bounds job_name.len <= 255 for the u8 cast below — TODO confirm
    try validateJobName(job_name);

    // Build binary message:
    // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] [commit_id: 20 bytes]
    const total_len = 1 + 16 + 1 + job_name.len + 20;
    var buffer = try self.allocator.alloc(u8, total_len);
    defer self.allocator.free(buffer);

    var offset: usize = 0;
    buffer[offset] = @intFromEnum(opcode.query_job);
    offset += 1;
    @memcpy(buffer[offset .. offset + 16], api_key_hash);
    offset += 16;
    buffer[offset] = @intCast(job_name.len);
    offset += 1;
    @memcpy(buffer[offset .. offset + job_name.len], job_name);
    offset += job_name.len;
    @memcpy(buffer[offset .. offset + 20], commit_id);

    try frame.sendWebSocketFrame(stream, buffer);
}
pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
const stream = try self.getStream();
try validateApiKeyHash(api_key_hash);

View file

@ -22,6 +22,9 @@ pub const Opcode = enum(u8) {
validate_request = 0x16,
// Job query opcode
query_job = 0x23,
// Logs and debug opcodes
get_logs = 0x20,
stream_logs = 0x21,
@ -68,6 +71,7 @@ pub const restore_jupyter = Opcode.restore_jupyter;
pub const list_jupyter = Opcode.list_jupyter;
pub const list_jupyter_packages = Opcode.list_jupyter_packages;
pub const validate_request = Opcode.validate_request;
pub const query_job = Opcode.query_job;
pub const get_logs = Opcode.get_logs;
pub const stream_logs = Opcode.stream_logs;
pub const attach_debug = Opcode.attach_debug;

View file

@ -1,4 +1,5 @@
const std = @import("std");
const ignore = @import("ignore.zig");
pub fn encodeHexLower(allocator: std.mem.Allocator, bytes: []const u8) ![]u8 {
const hex = try allocator.alloc(u8, bytes.len * 2);
@ -105,13 +106,20 @@ pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths:
return encodeHexLower(allocator, &hash);
}
/// Calculate commit ID for a directory (SHA256 of tree state)
/// Calculate commit ID for a directory (SHA256 of tree state, respecting .gitignore)
pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
defer dir.close();
// Load .gitignore and .mlignore patterns
var gitignore = ignore.GitIgnore.init(allocator);
defer gitignore.deinit();
try gitignore.loadFromDir(dir_path, ".gitignore");
try gitignore.loadFromDir(dir_path, ".mlignore");
var walker = try dir.walk(allocator);
defer walker.deinit(allocator);
@ -124,6 +132,12 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
while (try walker.next()) |entry| {
if (entry.kind == .file) {
// Skip files matching default ignores
if (ignore.matchesDefaultIgnore(entry.path)) continue;
// Skip files matching .gitignore/.mlignore patterns
if (gitignore.isIgnored(entry.path, false)) continue;
try paths.append(allocator, try allocator.dupe(u8, entry.path));
}
}

View file

@ -0,0 +1,333 @@
const std = @import("std");
const crypto = @import("crypto.zig");
const json = @import("json.zig");
/// Cache entry for a single file: the mtime observed when the file was
/// hashed plus the hash produced for that content. `hash` is owned by the
/// cache; the corresponding map key (the path) is freed by the map owner.
const CacheEntry = struct {
    mtime: i64,
    hash: []const u8,

    /// Free the owned hash string (the map key is freed separately).
    pub fn deinit(self: *const CacheEntry, allocator: std.mem.Allocator) void {
        allocator.free(self.hash);
    }
};

/// Hash cache that stores file mtimes and hashes to avoid re-hashing unchanged files.
/// Keys are owned copies of relative file paths; `dirty` tracks whether the
/// in-memory state has diverged from the on-disk JSON and needs a save().
pub const HashCache = struct {
    entries: std.StringHashMap(CacheEntry),
    allocator: std.mem.Allocator,
    // Set by load(); "" until then. deinit() frees it only when non-empty.
    cache_path: []const u8,
    dirty: bool,

    /// Create an empty cache. Call load() to populate from disk.
    pub fn init(allocator: std.mem.Allocator) HashCache {
        return .{
            .entries = std.StringHashMap(CacheEntry).init(allocator),
            .allocator = allocator,
            .cache_path = "",
            .dirty = false,
        };
    }

    /// Free all keys, values and the cache path. Does NOT persist; call
    /// save() first if the dirty state should reach disk.
    pub fn deinit(self: *HashCache) void {
        var it = self.entries.iterator();
        while (it.next()) |entry| {
            entry.value_ptr.deinit(self.allocator);
            self.allocator.free(entry.key_ptr.*);
        }
        self.entries.deinit();
        if (self.cache_path.len > 0) {
            self.allocator.free(self.cache_path);
        }
    }

    /// Get default cache path: ~/.ml/cache/hashes.json
    /// Creates the cache directory if missing. Caller owns the returned path.
    pub fn getDefaultPath(allocator: std.mem.Allocator) ![]const u8 {
        const home = std.posix.getenv("HOME") orelse {
            return error.NoHomeDirectory;
        };

        // Ensure cache directory exists
        // NOTE(review): makeDir is not recursive — if ~/.ml itself is
        // missing this returns FileNotFound; confirm makePath isn't needed.
        const cache_dir = try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache" });
        defer allocator.free(cache_dir);
        std.fs.cwd().makeDir(cache_dir) catch |err| switch (err) {
            error.PathAlreadyExists => {},
            else => return err,
        };

        return try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache", "hashes.json" });
    }

    /// Load cache from disk (a missing cache file is not an error).
    /// Records the resolved cache path for later save()/deinit().
    pub fn load(self: *HashCache) !void {
        const cache_path = try getDefaultPath(self.allocator);
        self.cache_path = cache_path;

        const file = std.fs.cwd().openFile(cache_path, .{}) catch |err| switch (err) {
            error.FileNotFound => return, // No cache yet is fine
            else => return err,
        };
        defer file.close();

        const content = try file.readToEndAlloc(self.allocator, 10 * 1024 * 1024); // Max 10MB
        defer self.allocator.free(content);

        // Parse JSON
        const parsed = try std.json.parseFromSlice(std.json.Value, self.allocator, content, .{});
        defer parsed.deinit();

        const root = parsed.value.object;
        const version = root.get("version") orelse return error.InvalidCacheFormat;
        if (version.integer != 1) return error.UnsupportedCacheVersion;

        const files = root.get("files") orelse return error.InvalidCacheFormat;
        if (files.object.count() == 0) return;

        var it = files.object.iterator();
        while (it.next()) |entry| {
            // NOTE(review): `path` is duped before the field checks below, so
            // it leaks when an entry is missing mtime/hash — dupe after
            // validation?
            const path = try self.allocator.dupe(u8, entry.key_ptr.*);
            const file_obj = entry.value_ptr.object;
            // Entries missing either field are skipped (tolerates partial writes).
            const mtime = file_obj.get("mtime") orelse continue;
            const hash_val = file_obj.get("hash") orelse continue;
            const hash = try self.allocator.dupe(u8, hash_val.string);
            try self.entries.put(path, .{
                .mtime = mtime.integer,
                .hash = hash,
            });
        }
    }

    /// Save cache to disk atomically (write a .tmp file, then rename over
    /// the real path). No-op when not dirty.
    pub fn save(self: *HashCache) !void {
        if (!self.dirty) return;

        // NOTE(review): this uses the managed ArrayList API
        // (.init/.deinit()/.writer()); other code in this file uses the
        // unmanaged API — confirm against the project's Zig version.
        var json_str = std.ArrayList(u8).init(self.allocator);
        defer json_str.deinit();
        var writer = json_str.writer();

        // Write header
        try writer.print("{{\n \"version\": 1,\n \"files\": {{\n", .{});

        // Write entries
        var it = self.entries.iterator();
        var first = true;
        while (it.next()) |entry| {
            if (!first) try writer.print(",\n", .{});
            first = false;

            // Escape path for JSON
            const escaped_path = try json.escapeString(self.allocator, entry.key_ptr.*);
            defer self.allocator.free(escaped_path);

            try writer.print(" \"{s}\": {{\"mtime\": {d}, \"hash\": \"{s}\"}}", .{
                escaped_path,
                entry.value_ptr.mtime,
                entry.value_ptr.hash,
            });
        }

        // Write footer
        try writer.print("\n }}\n}}\n", .{});

        // Write atomically
        const tmp_path = try std.fmt.allocPrint(self.allocator, "{s}.tmp", .{self.cache_path});
        defer self.allocator.free(tmp_path);

        {
            const file = try std.fs.cwd().createFile(tmp_path, .{});
            defer file.close();
            try file.writeAll(json_str.items);
        }

        try std.fs.cwd().rename(tmp_path, self.cache_path);
        self.dirty = false;
    }

    /// Check if file needs re-hashing (unknown path, or mtime changed).
    pub fn needsHash(self: *HashCache, path: []const u8, mtime: i64) bool {
        const entry = self.entries.get(path) orelse return true;
        return entry.mtime != mtime;
    }

    /// Get cached hash for file, or null when absent or stale. The returned
    /// slice is owned by the cache — do not free; it is invalidated by
    /// putHash/clear/deinit.
    pub fn getHash(self: *HashCache, path: []const u8, mtime: i64) ?[]const u8 {
        const entry = self.entries.get(path) orelse return null;
        if (entry.mtime != mtime) return null;
        return entry.hash;
    }

    /// Store hash for file. Copies both path and hash, replaces any prior
    /// entry for the same path, and marks the cache dirty.
    pub fn putHash(self: *HashCache, path: []const u8, mtime: i64, hash: []const u8) !void {
        const path_copy = try self.allocator.dupe(u8, path);

        // Remove old entry if exists (StringHashMap compares key content,
        // so path_copy locates the previously stored key).
        if (self.entries.fetchRemove(path_copy)) |old| {
            self.allocator.free(old.key);
            old.value.deinit(self.allocator);
        }

        const hash_copy = try self.allocator.dupe(u8, hash);
        try self.entries.put(path_copy, .{
            .mtime = mtime,
            .hash = hash_copy,
        });
        self.dirty = true;
    }

    /// Clear cache (e.g., after git checkout). Keeps map capacity for reuse
    /// and marks the cache dirty so the emptied state is persisted.
    pub fn clear(self: *HashCache) void {
        var it = self.entries.iterator();
        while (it.next()) |entry| {
            entry.value_ptr.deinit(self.allocator);
            self.allocator.free(entry.key_ptr.*);
        }
        self.entries.clearRetainingCapacity();
        self.dirty = true;
    }

    /// Get cache stats (entry count and dirty flag) for diagnostics.
    pub fn getStats(self: *HashCache) struct { entries: usize, dirty: bool } {
        return .{
            .entries = self.entries.count(),
            .dirty = self.dirty,
        };
    }
};
/// Calculate directory hash with cache support.
///
/// Produces the same deterministic digest scheme as crypto.hashDirectory
/// (sorted relative paths and per-file hashes, NUL-separated, folded into
/// SHA-256, hex-encoded), but consults `cache` so files whose mtime is
/// unchanged are not re-read. Caller owns the returned hex string.
/// Respects .gitignore/.mlignore plus the default ignore list.
pub fn hashDirectoryWithCache(
    allocator: std.mem.Allocator,
    dir_path: []const u8,
    cache: *HashCache,
) ![]const u8 {
    const ignore = @import("ignore.zig");
    // Named element type: the original passed a structurally-identical but
    // *distinct* anonymous struct type to std.sort.block, which cannot match
    // the list's element type.
    const FileEntry = struct { path: []const u8, mtime: i64, use_cache: bool };

    var hasher = std.crypto.hash.sha2.Sha256.init(.{});
    var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
    defer dir.close();

    // Load .gitignore / .mlignore patterns (missing files are fine).
    var gitignore = ignore.GitIgnore.init(allocator);
    defer gitignore.deinit();
    try gitignore.loadFromDir(dir_path, ".gitignore");
    try gitignore.loadFromDir(dir_path, ".mlignore");

    var walker = try dir.walk(allocator);
    defer walker.deinit(allocator);

    // Collect candidate files and decide per file whether the cache applies.
    var paths: std.ArrayList(FileEntry) = .{};
    defer {
        for (paths.items) |p| allocator.free(p.path);
        paths.deinit(allocator);
    }

    while (try walker.next()) |entry| {
        if (entry.kind != .file) continue;
        // Skip files matching default ignores
        if (ignore.matchesDefaultIgnore(entry.path)) continue;
        // Skip files matching .gitignore/.mlignore patterns
        if (gitignore.isIgnored(entry.path, false)) continue;

        const stat = dir.statFile(entry.path) catch |err| switch (err) {
            error.FileNotFound => continue, // file vanished mid-walk
            else => return err,
        };
        const mtime = @as(i64, @intCast(stat.mtime));
        // Unmanaged append, matching the `= .{}` init and deinit(allocator)
        // used above (the original mixed managed/unmanaged calls).
        try paths.append(allocator, .{
            .path = try allocator.dupe(u8, entry.path),
            .mtime = mtime,
            .use_cache = !cache.needsHash(entry.path, mtime),
        });
    }

    // Sort paths for deterministic hashing regardless of walk order.
    std.sort.block(FileEntry, paths.items, {}, struct {
        fn lessThan(_: void, a: FileEntry, b: FileEntry) bool {
            return std.mem.order(u8, a.path, b.path) == .lt;
        }
    }.lessThan);

    // Hash each file (using cache where possible).
    for (paths.items) |item| {
        hasher.update(item.path);
        hasher.update(&[_]u8{0}); // Separator
        const file_hash: []const u8 = if (item.use_cache)
            // Cache hit was established during the walk, so .? is safe here.
            cache.getHash(item.path, item.mtime).?
        else blk: {
            const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, item.path });
            defer allocator.free(full_path);
            const hash = try crypto.hashFile(allocator, full_path);
            try cache.putHash(item.path, item.mtime, hash); // putHash copies
            break :blk hash;
        };
        // Cached slices are owned by the cache; fresh ones are ours to free.
        defer if (!item.use_cache) allocator.free(file_hash);
        hasher.update(file_hash);
        hasher.update(&[_]u8{0}); // Separator
    }

    var hash: [32]u8 = undefined;
    hasher.final(&hash);
    return crypto.encodeHexLower(allocator, &hash);
}
// Exercises put/get round-trip plus mtime-based staleness detection.
test "HashCache basic operations" {
    const allocator = std.testing.allocator;

    var cache = HashCache.init(allocator);
    defer cache.deinit();

    // Put and get
    try cache.putHash("src/main.py", 1708369200, "abc123");
    const hash = cache.getHash("src/main.py", 1708369200);
    try std.testing.expect(hash != null);
    try std.testing.expectEqualStrings("abc123", hash.?);

    // Wrong mtime should return null
    const stale = cache.getHash("src/main.py", 1708369201);
    try std.testing.expect(stale == null);

    // needsHash should detect stale entries
    try std.testing.expect(cache.needsHash("src/main.py", 1708369201));
    try std.testing.expect(!cache.needsHash("src/main.py", 1708369200));
}

// Clearing must drop all entries and mark the cache dirty so the emptied
// state would be persisted by a subsequent save().
test "HashCache clear" {
    const allocator = std.testing.allocator;

    var cache = HashCache.init(allocator);
    defer cache.deinit();

    try cache.putHash("file1.py", 123, "hash1");
    try cache.putHash("file2.py", 456, "hash2");
    try std.testing.expectEqual(@as(usize, 2), cache.getStats().entries);

    cache.clear();
    try std.testing.expectEqual(@as(usize, 0), cache.getStats().entries);
    try std.testing.expect(cache.getStats().dirty);
}

261
cli/src/utils/ignore.zig Normal file
View file

@ -0,0 +1,261 @@
const std = @import("std");
/// Pattern type for ignore rules. Stores one normalized .gitignore-style
/// pattern (negation '!' and trailing '/' already stripped by addPattern).
const Pattern = struct {
    pattern: []const u8, // normalized pattern text, owned by GitIgnore.allocator
    is_negation: bool, // true if pattern starts with !
    is_dir_only: bool, // true if pattern ends with /
    anchored: bool, // true if the pattern contained '/', including a leading one (computed before the leading slash is stripped)
};
/// GitIgnore matcher for filtering files during directory traversal
pub const GitIgnore = struct {
patterns: std.ArrayList(Pattern),
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator) GitIgnore {
return .{
.patterns = std.ArrayList(Pattern).init(allocator),
.allocator = allocator,
};
}
pub fn deinit(self: *GitIgnore) void {
for (self.patterns.items) |p| {
self.allocator.free(p.pattern);
}
self.patterns.deinit();
}
/// Load .gitignore or .mlignore from directory
pub fn loadFromDir(self: *GitIgnore, dir_path: []const u8, filename: []const u8) !void {
const path = try std.fs.path.join(self.allocator, &[_][]const u8{ dir_path, filename });
defer self.allocator.free(path);
const file = std.fs.cwd().openFile(path, .{}) catch |err| switch (err) {
error.FileNotFound => return, // No ignore file is fine
else => return err,
};
defer file.close();
const content = try file.readToEndAlloc(self.allocator, 1024 * 1024); // Max 1MB
defer self.allocator.free(content);
try self.parse(content);
}
/// Parse ignore patterns from content
pub fn parse(self: *GitIgnore, content: []const u8) !void {
var lines = std.mem.split(u8, content, "\n");
while (lines.next()) |line| {
const trimmed = std.mem.trim(u8, line, " \t\r");
// Skip empty lines and comments
if (trimmed.len == 0 or std.mem.startsWith(u8, trimmed, "#")) continue;
try self.addPattern(trimmed);
}
}
/// Add a single pattern
fn addPattern(self: *GitIgnore, pattern: []const u8) !void {
var p = pattern;
var is_negation = false;
var is_dir_only = false;
// Check for negation
if (std.mem.startsWith(u8, p, "!")) {
is_negation = true;
p = p[1..];
}
// Check for directory-only marker
if (std.mem.endsWith(u8, p, "/")) {
is_dir_only = true;
p = p[0 .. p.len - 1];
}
// Remove leading slash (anchored patterns)
const anchored = std.mem.indexOf(u8, p, "/") != null;
if (std.mem.startsWith(u8, p, "/")) {
p = p[1..];
}
// Store normalized pattern
const pattern_copy = try self.allocator.dupe(u8, p);
try self.patterns.append(.{
.pattern = pattern_copy,
.is_negation = is_negation,
.is_dir_only = is_dir_only,
.anchored = anchored,
});
}
/// Check if a path should be ignored
pub fn isIgnored(self: *GitIgnore, path: []const u8, is_dir: bool) bool {
var ignored = false;
for (self.patterns.items) |pattern| {
if (self.matches(pattern, path, is_dir)) {
ignored = !pattern.is_negation;
}
}
return ignored;
}
/// Check if a single pattern matches a path.
fn matches(self: *GitIgnore, pattern: Pattern, path: []const u8, is_dir: bool) bool {
    _ = self;
    // Directory-only patterns ("build/") never match plain files.
    if (pattern.is_dir_only and !is_dir) return false;
    // Try the full relative path first.
    if (patternMatch(pattern.pattern, path)) return true;
    // Unanchored patterns additionally match the basename alone.
    if (pattern.anchored) return false;
    const slash = std.mem.lastIndexOf(u8, path, "/") orelse return false;
    return patternMatch(pattern.pattern, path[slash + 1 ..]);
}
/// Simple glob pattern matching with backtracking.
/// Supports '*' (any run of characters), '?' (exactly one character) and
/// literal characters. A "**" anywhere is treated as match-everything,
/// mirroring the previous implementation's coarse handling.
fn patternMatch(pattern: []const u8, path: []const u8) bool {
    var p_idx: usize = 0;
    var s_idx: usize = 0;
    // Backtracking state: position of the most recent '*' and the path
    // index it last absorbed up to. Without this, a '*' greedily jumped to
    // the first occurrence of the next literal and never retried, so
    // "*.log" failed to match "a.b.log".
    var star_p: ?usize = null;
    var star_s: usize = 0;
    while (s_idx < path.len) {
        if (p_idx < pattern.len and pattern[p_idx] == '*') {
            // "**" matches any number of directories — treat as match-all.
            if (p_idx + 1 < pattern.len and pattern[p_idx + 1] == '*') return true;
            star_p = p_idx;
            star_s = s_idx;
            p_idx += 1;
        } else if (p_idx < pattern.len and
            (pattern[p_idx] == '?' or pattern[p_idx] == path[s_idx]))
        {
            // '?' consumes exactly one character; literals must be equal.
            p_idx += 1;
            s_idx += 1;
        } else if (star_p) |sp| {
            // Mismatch after a '*': widen the star by one character and retry.
            // NOTE(review): like the original, '*' may cross '/' here.
            p_idx = sp + 1;
            star_s += 1;
            s_idx = star_s;
        } else {
            return false;
        }
    }
    // Path fully consumed: any trailing '*'s match the empty string.
    while (p_idx < pattern.len and pattern[p_idx] == '*') p_idx += 1;
    return p_idx == pattern.len;
}
};
/// Default patterns always ignored (like git does)
/// Consumed by `matchesDefaultIgnore`; entries beginning with '*' are
/// wildcard patterns, the rest are matched as exact names.
pub const DEFAULT_IGNORES = [_][]const u8{
// Version control and tool state
".git",
".ml",
// Python bytecode caches
"__pycache__",
"*.pyc",
"*.pyo",
// OS metadata
".DS_Store",
// Dependency and environment directories
"node_modules",
".venv",
"venv",
".env",
// Editor / IDE state
".idea",
".vscode",
// Transient and backup files
"*.log",
"*.tmp",
"*.swp",
"*.swo",
"*~",
};
/// Check if path matches default ignores.
/// The path's basename is compared against DEFAULT_IGNORES: entries
/// starting with '*' match by suffix ("*.pyc", "*~"); all others must
/// equal the basename exactly.
pub fn matchesDefaultIgnore(path: []const u8) bool {
    // BUG FIX: the basename is the path itself when there is no '/'.
    // Previously suffix patterns were only checked for paths containing a
    // separator, so "test.pyc" was never matched against "*.pyc" (its own
    // unit test exercised exactly that case and failed). Comparing exact
    // names against the basename also ignores nested defaults such as
    // "src/node_modules", matching git's behavior for these names.
    const basename = if (std.mem.lastIndexOf(u8, path, "/")) |idx|
        path[idx + 1 ..]
    else
        path;
    for (DEFAULT_IGNORES) |pattern| {
        if (pattern.len > 1 and pattern[0] == '*') {
            // Suffix patterns: "*.pyc" → ".pyc", "*~" → "~" (the latter
            // was previously unhandled because only "*." was recognized).
            if (std.mem.endsWith(u8, basename, pattern[1..])) return true;
        } else if (std.mem.eql(u8, basename, pattern)) {
            return true;
        }
    }
    return false;
}
test "GitIgnore basic patterns" {
    var gi = GitIgnore.init(std.testing.allocator);
    defer gi.deinit();
    try gi.parse("node_modules\n__pycache__\n*.pyc\n");
    // Listed directory names and wildcard extensions are excluded...
    try std.testing.expect(gi.isIgnored("node_modules", true));
    try std.testing.expect(gi.isIgnored("__pycache__", true));
    try std.testing.expect(gi.isIgnored("test.pyc", false));
    // ...while unrelated files pass through.
    try std.testing.expect(!gi.isIgnored("main.py", false));
}
test "GitIgnore negation" {
    var gi = GitIgnore.init(std.testing.allocator);
    defer gi.deinit();
    try gi.parse("*.log\n!important.log\n");
    // A later "!" pattern re-includes a file excluded earlier.
    try std.testing.expect(gi.isIgnored("debug.log", false));
    try std.testing.expect(!gi.isIgnored("important.log", false));
}
test "GitIgnore directory-only" {
    var gi = GitIgnore.init(std.testing.allocator);
    defer gi.deinit();
    try gi.parse("build/\n");
    // A trailing slash restricts the pattern to directories only.
    try std.testing.expect(gi.isIgnored("build", true));
    try std.testing.expect(!gi.isIgnored("build", false));
}
test "matchesDefaultIgnore" {
// Exact default names are ignored.
try std.testing.expect(matchesDefaultIgnore(".git"));
try std.testing.expect(matchesDefaultIgnore("__pycache__"));
try std.testing.expect(matchesDefaultIgnore("node_modules"));
// Wildcard extension "*.pyc" matches by suffix.
try std.testing.expect(matchesDefaultIgnore("test.pyc"));
// Ordinary source files pass through.
try std.testing.expect(!matchesDefaultIgnore("main.py"));
}

View file

@ -0,0 +1,195 @@
const std = @import("std");
const c = @cImport({
@cInclude("dataset_hash.h");
});
/// Native hash context for high-performance file hashing.
/// Wraps the C `fh_*` API; every returned string is copied into
/// allocator-owned memory and the C originals are released with
/// fh_free_string.
pub const NativeHasher = struct {
    ctx: *c.fh_context_t,
    allocator: std.mem.Allocator,

    /// Initialize native hasher with thread pool.
    /// num_threads: 0 = auto-detect (use hardware concurrency).
    pub fn init(allocator: std.mem.Allocator, num_threads: u32) !NativeHasher {
        const ctx = c.fh_init(num_threads);
        if (ctx == null) return error.NativeInitFailed;
        return .{
            .ctx = ctx,
            .allocator = allocator,
        };
    }

    /// Cleanup native hasher and thread pool.
    pub fn deinit(self: *NativeHasher) void {
        c.fh_cleanup(self.ctx);
    }

    /// Duplicate an array of native C strings into allocator-owned Zig
    /// strings, releasing every C string exactly once — including when an
    /// allocation fails partway through. Returns a caller-owned slice.
    fn dupeAndFreeAll(allocator: std.mem.Allocator, c_strings: [][*c]u8) ![][]const u8 {
        const out = allocator.alloc([]const u8, c_strings.len) catch |err| {
            for (c_strings) |s| c.fh_free_string(s);
            return err;
        };
        var filled: usize = 0;
        var fail: ?anyerror = null;
        for (c_strings) |s| {
            if (fail == null) {
                if (allocator.dupe(u8, std.mem.span(s))) |copy| {
                    out[filled] = copy;
                    filled += 1;
                } else |err| {
                    fail = err;
                }
            }
            // Free the C string regardless of duplication success.
            c.fh_free_string(s);
        }
        if (fail) |err| {
            for (out[0..filled]) |h| allocator.free(h);
            allocator.free(out);
            return err;
        }
        return out;
    }

    /// Hash a single file. Caller owns the returned string.
    pub fn hashFile(self: *NativeHasher, path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, path);
        defer self.allocator.free(c_path);
        const result = c.fh_hash_file(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);
        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Batch hash multiple files (amortizes FFI overhead).
    /// Caller owns each returned string and the outer slice.
    pub fn hashBatch(self: *NativeHasher, paths: []const []const u8) ![][]const u8 {
        // Build the C string array. `filled` tracks how many entries are
        // initialized so the cleanup below never reads undefined pointers.
        // (The previous version freed ALL entries on partial failure,
        // which dereferenced uninitialized [*c] pointers — UB.)
        const c_paths = try self.allocator.alloc([*c]const u8, paths.len);
        defer self.allocator.free(c_paths);
        var filled: usize = 0;
        defer for (c_paths[0..filled]) |p| self.allocator.free(std.mem.span(p));
        for (paths) |path| {
            const c_path = try self.allocator.dupeZ(u8, path);
            c_paths[filled] = c_path.ptr;
            filled += 1;
        }
        // Output array for the native call.
        const results = try self.allocator.alloc([*c]u8, paths.len);
        defer self.allocator.free(results);
        const ret = c.fh_hash_batch(self.ctx, c_paths.ptr, @intCast(paths.len), results.ptr);
        if (ret != 0) return error.HashFailed;
        // Copy into Zig-owned strings; every C string is freed exactly once.
        return dupeAndFreeAll(self.allocator, results);
    }

    /// Hash entire directory (combined hash). Caller owns the result.
    pub fn hashDirectory(self: *NativeHasher, dir_path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);
        const result = c.fh_hash_directory(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);
        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Hash directory with batch output (individual file hashes).
    /// At most `max_results` entries are returned; `count` reports how
    /// many the native side actually produced. Caller owns both slices
    /// and every string in them.
    pub fn hashDirectoryBatch(
        self: *NativeHasher,
        dir_path: []const u8,
        max_results: u32,
    ) !struct { hashes: [][]const u8, paths: [][]const u8, count: u32 } {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);
        // Native-side output arrays.
        const hashes = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(hashes);
        const paths = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(paths);
        var count: u32 = 0;
        const ret = c.fh_hash_directory_batch(
            self.ctx,
            c_path.ptr,
            hashes.ptr,
            paths.ptr,
            max_results,
            &count,
        );
        if (ret != 0) return error.HashFailed;
        // Convert hashes first; on failure the untouched C path strings
        // must still be released before the error propagates. (The old
        // errdefers freed uninitialized slice entries on partial failure
        // and leaked the remaining C strings.)
        const zig_hashes = dupeAndFreeAll(self.allocator, hashes[0..count]) catch |err| {
            for (paths[0..count]) |p| c.fh_free_string(p);
            return err;
        };
        errdefer {
            for (zig_hashes) |h| self.allocator.free(h);
            self.allocator.free(zig_hashes);
        }
        const zig_paths = try dupeAndFreeAll(self.allocator, paths[0..count]);
        return .{
            .hashes = zig_hashes,
            .paths = zig_paths,
            .count = count,
        };
    }

    /// Check if SIMD SHA-256 is available in the native library.
    pub fn hasSimd(self: *NativeHasher) bool {
        _ = self;
        return c.fh_has_simd_sha256() != 0;
    }

    /// Get implementation info (SIMD type, etc.) as a static native string.
    pub fn getImplInfo(self: *NativeHasher) []const u8 {
        _ = self;
        return std.mem.span(c.fh_get_simd_impl_name());
    }
};
/// Convenience function: hash directory using native library.
/// Creates a short-lived hasher; prefer a long-lived NativeHasher when
/// hashing repeatedly.
pub fn hashDirectoryNative(allocator: std.mem.Allocator, dir_path: []const u8) ![]const u8 {
    // 0 threads = let the native layer pick hardware concurrency.
    var hasher = try NativeHasher.init(allocator, 0);
    defer hasher.deinit();
    return hasher.hashDirectory(dir_path);
}
/// Convenience function: batch hash files using native library.
/// Caller owns the returned strings and the outer slice.
pub fn hashFilesNative(
    allocator: std.mem.Allocator,
    paths: []const []const u8,
) ![][]const u8 {
    var hasher = try NativeHasher.init(allocator, 0); // auto-detect threads
    defer hasher.deinit();
    return hasher.hashBatch(paths);
}
test "NativeHasher basic operations" {
const allocator = std.testing.allocator;
// Skip if native library not available (e.g. the C library was not
// linked into the test binary) instead of failing the suite.
var hasher = NativeHasher.init(allocator, 1) catch |err| {
    if (err == error.NativeInitFailed) {
        std.debug.print("Native library not available, skipping test\n", .{});
        return;
    }
    return err;
};
defer hasher.deinit();
// Smoke-check the capability queries; values are informational only.
const has_simd = hasher.hasSimd();
const impl_name = hasher.getImplInfo();
std.debug.print("SIMD: {any}, Impl: {s}\n", .{ has_simd, impl_name });
}

View file

@ -0,0 +1,231 @@
const std = @import("std");
const ignore = @import("ignore.zig");
/// Thread-safe work queue for parallel directory walking.
///
/// Termination protocol: `pending` counts items that are either queued or
/// currently being processed by a worker. `pop` blocks while the queue is
/// empty but work is still in flight (a worker may yet push new
/// directories); it returns null once the queue is drained and `pending`
/// reaches zero, which `finishItem` signals. Previously nothing ever woke
/// idle workers when the queue ran dry, so `parallelWalk` deadlocked in
/// `thread.join()`.
const WorkQueue = struct {
    items: std.ArrayList(WorkItem),
    mutex: std.Thread.Mutex,
    condition: std.Thread.Condition,
    pending: usize,
    done: bool,

    const WorkItem = struct {
        path: []const u8,
        depth: usize,
    };

    fn init(allocator: std.mem.Allocator) WorkQueue {
        return .{
            .items = std.ArrayList(WorkItem).init(allocator),
            .mutex = .{},
            .condition = .{},
            .pending = 0,
            .done = false,
        };
    }

    fn deinit(self: *WorkQueue, allocator: std.mem.Allocator) void {
        for (self.items.items) |item| {
            allocator.free(item.path);
        }
        self.items.deinit();
    }

    /// Queue a directory. The matching `finishItem` call decrements
    /// `pending` once the popped item has been fully processed.
    fn push(self: *WorkQueue, path: []const u8, depth: usize, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();
        try self.items.append(.{
            .path = try allocator.dupe(u8, path),
            .depth = depth,
        });
        self.pending += 1;
        self.condition.signal();
    }

    /// Block until an item is available. Null means all work has drained
    /// (or `setDone` forced shutdown) and the worker should exit.
    fn pop(self: *WorkQueue) ?WorkItem {
        self.mutex.lock();
        defer self.mutex.unlock();
        while (self.items.items.len == 0 and self.pending > 0 and !self.done) {
            self.condition.wait(&self.mutex);
        }
        if (self.items.items.len == 0) return null;
        return self.items.pop();
    }

    /// Mark one popped item as fully processed. Wakes all waiters when the
    /// last in-flight item completes so they can observe termination.
    fn finishItem(self: *WorkQueue) void {
        self.mutex.lock();
        defer self.mutex.unlock();
        self.pending -= 1;
        if (self.pending == 0) self.condition.broadcast();
    }

    /// Force shutdown: workers drain the remaining queued items, then exit.
    fn setDone(self: *WorkQueue) void {
        self.mutex.lock();
        defer self.mutex.unlock();
        self.done = true;
        self.condition.broadcast();
    }
};

/// Accumulates discovered file paths across worker threads.
const WalkResult = struct {
    files: std.ArrayList([]const u8),
    mutex: std.Thread.Mutex,

    fn init(allocator: std.mem.Allocator) WalkResult {
        return .{
            .files = std.ArrayList([]const u8).init(allocator),
            .mutex = .{},
        };
    }

    fn deinit(self: *WalkResult, allocator: std.mem.Allocator) void {
        for (self.files.items) |file| {
            allocator.free(file);
        }
        self.files.deinit();
    }

    /// Append a copy of `path`; safe to call from any worker thread.
    fn add(self: *WalkResult, path: []const u8, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();
        try self.files.append(try allocator.dupe(u8, path));
    }
};

/// Shared context handed to every worker thread.
const ThreadContext = struct {
    queue: *WorkQueue,
    result: *WalkResult,
    gitignore: *ignore.GitIgnore,
    base_path: []const u8,
    allocator: std.mem.Allocator,
    max_depth: usize,
};

/// Worker thread body: drain the queue until it signals termination.
fn walkWorker(ctx: *ThreadContext) void {
    while (true) {
        const item = ctx.queue.pop() orelse break;
        // Every popped item must be acknowledged — even on the depth-limit
        // and error paths — or the queue's pending count never drains.
        defer ctx.queue.finishItem();
        defer ctx.allocator.free(item.path);
        if (item.depth >= ctx.max_depth) continue;
        walkDirectoryParallel(ctx, item.path, item.depth) catch |err| {
            std.log.warn("Error walking {s}: {any}", .{ item.path, err });
        };
    }
}

/// Walk a single directory: files go to the shared result set,
/// subdirectories go back onto the work queue for any worker to pick up.
fn walkDirectoryParallel(ctx: *ThreadContext, dir_path: []const u8, depth: usize) !void {
    // "." denotes the base directory itself; anything else is joined onto it.
    const full_path = if (std.mem.eql(u8, dir_path, "."))
        ctx.base_path
    else
        try std.fs.path.join(ctx.allocator, &[_][]const u8{ ctx.base_path, dir_path });
    defer if (!std.mem.eql(u8, dir_path, ".")) ctx.allocator.free(full_path);
    // Unreadable or vanished directories are skipped, not fatal.
    var dir = std.fs.cwd().openDir(full_path, .{ .iterate = true }) catch |err| switch (err) {
        error.AccessDenied => return,
        error.FileNotFound => return,
        else => return err,
    };
    defer dir.close();
    var it = dir.iterate();
    while (true) {
        const entry = it.next() catch |err| switch (err) {
            error.AccessDenied => continue,
            else => return err,
        } orelse break;
        // Queue/result paths are kept relative to the base directory.
        const entry_path = if (std.mem.eql(u8, dir_path, "."))
            try std.fmt.allocPrint(ctx.allocator, "{s}", .{entry.name})
        else
            try std.fs.path.join(ctx.allocator, &[_][]const u8{ dir_path, entry.name });
        defer ctx.allocator.free(entry_path);
        // Built-in ignores (.git, node_modules, ...) first, then user patterns.
        if (ignore.matchesDefaultIgnore(entry_path)) continue;
        const is_dir = entry.kind == .directory;
        if (ctx.gitignore.isIgnored(entry_path, is_dir)) continue;
        if (is_dir) {
            // Hand the subdirectory to the queue; failure is non-fatal.
            ctx.queue.push(entry_path, depth + 1, ctx.allocator) catch |err| {
                std.log.warn("Failed to queue {s}: {any}", .{ entry_path, err });
            };
        } else if (entry.kind == .file) {
            ctx.result.add(entry_path, ctx.allocator) catch |err| {
                std.log.warn("Failed to add file {s}: {any}", .{ entry_path, err });
            };
        }
    }
}

/// Parallel directory walker that uses multiple threads.
/// Returns caller-owned file paths (relative to `base_path`), sorted for
/// deterministic ordering. `num_threads` of 0 is treated as 1 so the
/// seeded work item is always processed.
pub fn parallelWalk(
    allocator: std.mem.Allocator,
    base_path: []const u8,
    gitignore: *ignore.GitIgnore,
    num_threads: usize,
) ![][]const u8 {
    var queue = WorkQueue.init(allocator);
    defer queue.deinit(allocator);
    var result = WalkResult.init(allocator);
    defer result.deinit(allocator);
    // Seed the queue with the base directory itself.
    try queue.push(".", 0, allocator);
    var ctx = ThreadContext{
        .queue = &queue,
        .result = &result,
        .gitignore = gitignore,
        .base_path = base_path,
        .allocator = allocator,
        .max_depth = 100, // Prevent runaway recursion (e.g. cyclic links).
    };
    const worker_count = if (num_threads == 0) 1 else num_threads;
    const threads = try allocator.alloc(std.Thread, worker_count);
    defer allocator.free(threads);
    {
        // If a spawn fails, shut the queue down and join the workers that
        // already started — they reference stack-local state above.
        var spawned: usize = 0;
        errdefer {
            queue.setDone();
            for (threads[0..spawned]) |t| t.join();
        }
        while (spawned < worker_count) {
            threads[spawned] = try std.Thread.spawn(.{}, walkWorker, .{&ctx});
            spawned += 1;
        }
    }
    // Workers exit on their own once the pending count drains to zero.
    for (threads) |thread| {
        thread.join();
    }
    // Sort results for deterministic ordering.
    std.sort.block([]const u8, result.files.items, {}, struct {
        fn lessThan(_: void, a: []const u8, b: []const u8) bool {
            return std.mem.order(u8, a, b) == .lt;
        }
    }.lessThan);
    // Transfer string ownership to the caller; zeroing the length keeps
    // result.deinit from freeing the transferred strings.
    const files = try allocator.alloc([]const u8, result.files.items.len);
    @memcpy(files, result.files.items);
    result.files.items.len = 0;
    return files;
}
test "parallelWalk basic" {
const allocator = std.testing.allocator;
// Empty ignore set: only the built-in default ignores apply.
var gitignore = ignore.GitIgnore.init(allocator);
defer gitignore.deinit();
// Walk the current directory with 4 threads; caller owns the result.
const files = try parallelWalk(allocator, ".", &gitignore, 4);
defer {
    for (files) |f| allocator.free(f);
    allocator.free(files);
}
// Should find at least some files (the package's own sources).
try std.testing.expect(files.len > 0);
}

283
cli/src/utils/watch.zig Normal file
View file

@ -0,0 +1,283 @@
const std = @import("std");
const os = std.os;
const log = std.log;
/// File watcher using OS-native APIs (kqueue on macOS, inotify on Linux)
/// Zero third-party dependencies - uses only standard library OS bindings
pub const FileWatcher = struct {
    allocator: std.mem.Allocator,
    watched_paths: std.StringHashMap(void),
    // Platform-specific handles
    kqueue_fd: if (@import("builtin").target.os.tag == .macos) i32 else void,
    inotify_fd: if (@import("builtin").target.os.tag == .linux) i32 else void,
    // Descriptors registered with kqueue. EVFILT_VNODE events are only
    // delivered while the watched descriptor remains open, so these are
    // retained until deinit. (Previously each fd was closed right after
    // registration, which silently cancelled the watch.)
    watch_fds: std.ArrayList(i32),
    // Debounce timer
    last_event_time: i64,
    debounce_ms: i64,

    /// Create a watcher; `debounce_ms` suppresses event bursts closer
    /// together than this interval.
    pub fn init(allocator: std.mem.Allocator, debounce_ms: i64) !FileWatcher {
        const target = @import("builtin").target;
        var watcher = FileWatcher{
            .allocator = allocator,
            .watched_paths = std.StringHashMap(void).init(allocator),
            .kqueue_fd = if (target.os.tag == .macos) -1 else {},
            .inotify_fd = if (target.os.tag == .linux) -1 else {},
            .watch_fds = std.ArrayList(i32).init(allocator),
            .last_event_time = 0,
            .debounce_ms = debounce_ms,
        };
        switch (target.os.tag) {
            .macos => {
                watcher.kqueue_fd = try os.kqueue();
            },
            .linux => {
                watcher.inotify_fd = try std.os.inotify_init1(os.linux.IN_CLOEXEC);
            },
            else => {
                return error.UnsupportedPlatform;
            },
        }
        return watcher;
    }

    pub fn deinit(self: *FileWatcher) void {
        const target = @import("builtin").target;
        switch (target.os.tag) {
            .macos => {
                // Release per-path vnode descriptors before the queue itself.
                for (self.watch_fds.items) |fd| {
                    os.close(fd);
                }
                if (self.kqueue_fd != -1) {
                    os.close(self.kqueue_fd);
                }
            },
            .linux => {
                if (self.inotify_fd != -1) {
                    os.close(self.inotify_fd);
                }
            },
            else => {},
        }
        self.watch_fds.deinit();
        var it = self.watched_paths.keyIterator();
        while (it.next()) |key| {
            self.allocator.free(key.*);
        }
        self.watched_paths.deinit();
    }

    /// Add a directory to watch recursively.
    /// Paths already registered are skipped.
    pub fn watchDirectory(self: *FileWatcher, path: []const u8) !void {
        if (self.watched_paths.contains(path)) return;
        const path_copy = try self.allocator.dupe(u8, path);
        try self.watched_paths.put(path_copy, {});
        const target = @import("builtin").target;
        switch (target.os.tag) {
            .macos => try self.addKqueueWatch(path),
            .linux => try self.addInotifyWatch(path),
            else => return error.UnsupportedPlatform,
        }
        // Recursively watch subdirectories; unreadable or vanished
        // directories are skipped rather than failing the whole setup.
        var dir = std.fs.cwd().openDir(path, .{ .iterate = true }) catch |err| switch (err) {
            error.AccessDenied => return,
            error.FileNotFound => return,
            else => return err,
        };
        defer dir.close();
        var it = dir.iterate();
        while (true) {
            const entry = it.next() catch |err| switch (err) {
                error.AccessDenied => continue,
                else => return err,
            } orelse break;
            if (entry.kind == .directory) {
                const subpath = try std.fs.path.join(self.allocator, &[_][]const u8{ path, entry.name });
                defer self.allocator.free(subpath);
                try self.watchDirectory(subpath);
            }
        }
    }

    /// Add kqueue watch for macOS.
    /// BUG FIX: the opened descriptor must outlive this call — closing it
    /// deregisters the kqueue vnode filter — so it is stored in
    /// `watch_fds` and closed only in deinit.
    fn addKqueueWatch(self: *FileWatcher, path: []const u8) !void {
        if (@import("builtin").target.os.tag != .macos) return;
        const fd = try os.open(path, os.O.EVTONLY | os.O_RDONLY, 0);
        errdefer os.close(fd);
        const event = os.Kevent{
            .ident = @intCast(fd),
            .filter = os.EVFILT_VNODE,
            .flags = os.EV_ADD | os.EV_CLEAR,
            .fflags = os.NOTE_WRITE | os.NOTE_EXTEND | os.NOTE_ATTRIB | os.NOTE_LINK | os.NOTE_RENAME | os.NOTE_REVOKE,
            .data = 0,
            .udata = 0,
        };
        const changes = [_]os.Kevent{event};
        _ = try os.kevent(self.kqueue_fd, &changes, &.{}, null);
        try self.watch_fds.append(fd);
    }

    /// Add inotify watch for Linux.
    fn addInotifyWatch(self: *FileWatcher, path: []const u8) !void {
        if (@import("builtin").target.os.tag != .linux) return;
        const mask = os.linux.IN_MODIFY | os.linux.IN_CREATE | os.linux.IN_DELETE |
            os.linux.IN_MOVED_FROM | os.linux.IN_MOVED_TO | os.linux.IN_ATTRIB;
        const wd = try os.linux.inotify_add_watch(self.inotify_fd, path.ptr, mask);
        if (wd < 0) return error.InotifyError;
    }

    /// Wait for file changes with debouncing.
    /// Returns true when a (debounced) change occurred; timeout_ms < 0
    /// blocks indefinitely.
    pub fn waitForChanges(self: *FileWatcher, timeout_ms: i32) !bool {
        const target = @import("builtin").target;
        switch (target.os.tag) {
            .macos => return try self.waitKqueue(timeout_ms),
            .linux => return try self.waitInotify(timeout_ms),
            else => return error.UnsupportedPlatform,
        }
    }

    /// kqueue wait implementation.
    fn waitKqueue(self: *FileWatcher, timeout_ms: i32) !bool {
        if (@import("builtin").target.os.tag != .macos) return false;
        var ts: os.timespec = undefined;
        if (timeout_ms >= 0) {
            ts.tv_sec = @divTrunc(timeout_ms, 1000);
            ts.tv_nsec = @mod(timeout_ms, 1000) * 1000000;
        }
        var events: [10]os.Kevent = undefined;
        const nev = os.kevent(self.kqueue_fd, &.{}, &events, if (timeout_ms >= 0) &ts else null) catch |err| switch (err) {
            error.ETIME => return false, // Timeout
            else => return err,
        };
        if (nev > 0) {
            // Debounce: only report if enough time passed since the last event.
            const now = std.time.milliTimestamp();
            if (now - self.last_event_time > self.debounce_ms) {
                self.last_event_time = now;
                return true;
            }
        }
        return false;
    }

    /// inotify wait implementation.
    fn waitInotify(self: *FileWatcher, timeout_ms: i32) !bool {
        if (@import("builtin").target.os.tag != .linux) return false;
        var fds = [_]os.pollfd{.{
            .fd = self.inotify_fd,
            .events = os.POLLIN,
            .revents = 0,
        }};
        const ready = os.poll(&fds, timeout_ms) catch |err| switch (err) {
            error.ETIME => return false,
            else => return err,
        };
        if (ready > 0 and (fds[0].revents & os.POLLIN) != 0) {
            // Drain pending events; their contents are not inspected, only
            // the fact that something changed matters here.
            var buf: [4096]u8 align(@alignOf(os.linux.inotify_event)) = undefined;
            const bytes_read = try os.read(self.inotify_fd, &buf);
            if (bytes_read > 0) {
                const now = std.time.milliTimestamp();
                if (now - self.last_event_time > self.debounce_ms) {
                    self.last_event_time = now;
                    return true;
                }
            }
        }
        return false;
    }

    /// Run watch loop with callback; blocks forever.
    pub fn run(self: *FileWatcher, callback: fn () void) !void {
        log.info("Watching for file changes (debounce: {d}ms)...", .{self.debounce_ms});
        while (true) {
            if (try self.waitForChanges(-1)) {
                callback();
            }
        }
    }
};
/// Watch command handler: `ml watch <path> [--sync] [--queue]`.
/// Blocks forever, reporting debounced change events for `path`.
pub fn watchCommand(allocator: std.mem.Allocator, args: []const []const u8) !void {
    if (args.len < 1) {
        std.debug.print("Usage: ml watch <path> [--sync] [--queue]\n", .{});
        return error.InvalidArgs;
    }
    const watch_path = args[0];

    // Parse optional flags.
    var auto_sync = false;
    var auto_queue = false;
    for (args[1..]) |flag| {
        if (std.mem.eql(u8, flag, "--sync")) auto_sync = true;
        if (std.mem.eql(u8, flag, "--queue")) auto_queue = true;
    }

    // 100ms debounce collapses rapid editor save bursts into one event.
    var watcher = try FileWatcher.init(allocator, 100);
    defer watcher.deinit();
    try watcher.watchDirectory(watch_path);
    log.info("Watching {s} for changes...", .{watch_path});

    // Context for the eventual sync/queue hooks (currently placeholder).
    const CallbackContext = struct {
        allocator: std.mem.Allocator,
        path: []const u8,
        auto_sync: bool,
        auto_queue: bool,
    };
    const ctx = CallbackContext{
        .allocator = allocator,
        .path = watch_path,
        .auto_sync = auto_sync,
        .auto_queue = auto_queue,
    };

    // Event loop: block indefinitely, act on debounced change batches.
    while (true) {
        if (!try watcher.waitForChanges(-1)) continue;
        log.info("File changes detected", .{});
        if (auto_sync) {
            log.info("Auto-syncing...", .{});
            // Trigger sync (implementation would call sync command)
            _ = ctx;
        }
        if (auto_queue) {
            log.info("Auto-queuing...", .{});
            // Trigger queue (implementation would call queue command)
        }
    }
}
test "FileWatcher init/deinit" {
// Smoke test: construction must succeed on supported platforms and
// deinit must release all OS handles and owned path copies.
const allocator = std.testing.allocator;
var watcher = try FileWatcher.init(allocator, 100);
defer watcher.deinit();
}

View file

@ -356,7 +356,8 @@ api_key = "your_api_key_here" # Your API key (get from admin)
// Set proper permissions
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
log.Printf("Warning: %v", err)
// Log permission warning but don't fail
_ = err
}
return nil

View file

@ -17,12 +17,14 @@ type Config struct {
SSHKey string `toml:"ssh_key"`
Port int `toml:"port"`
BasePath string `toml:"base_path"`
Mode string `toml:"mode"` // "dev" or "prod"
WrapperScript string `toml:"wrapper_script"`
TrainScript string `toml:"train_script"`
RedisAddr string `toml:"redis_addr"`
RedisPassword string `toml:"redis_password"`
RedisDB int `toml:"redis_db"`
KnownHosts string `toml:"known_hosts"`
ServerURL string `toml:"server_url"` // WebSocket server URL (e.g., ws://localhost:8080)
// Authentication
Auth auth.Config `toml:"auth"`
@ -133,6 +135,20 @@ func (c *Config) Validate() error {
}
}
// Set default mode if not specified
if c.Mode == "" {
if os.Getenv("FETCH_ML_TUI_MODE") != "" {
c.Mode = os.Getenv("FETCH_ML_TUI_MODE")
} else {
c.Mode = "dev" // Default to dev mode
}
}
// Set mode-appropriate default paths using project-relative paths
if c.BasePath == "" {
c.BasePath = utils.ModeBasedBasePath(c.Mode)
}
if c.BasePath != "" {
// Convert relative paths to absolute
c.BasePath = utils.ExpandPath(c.BasePath)

View file

@ -24,6 +24,7 @@ func (c *Controller) loadAllData() tea.Cmd {
c.loadQueue(),
c.loadGPU(),
c.loadContainer(),
c.loadDatasets(),
)
}
@ -39,6 +40,13 @@ func (c *Controller) loadJobs() tea.Cmd {
var jobs []model.Job
statusChan := make(chan []model.Job, 4)
// Debug: Print paths being used
c.logger.Info("Loading jobs from paths",
"pending", c.getPathForStatus(model.StatusPending),
"running", c.getPathForStatus(model.StatusRunning),
"finished", c.getPathForStatus(model.StatusFinished),
"failed", c.getPathForStatus(model.StatusFailed))
for _, status := range []model.JobStatus{
model.StatusPending,
model.StatusRunning,
@ -48,22 +56,18 @@ func (c *Controller) loadJobs() tea.Cmd {
go func(s model.JobStatus) {
path := c.getPathForStatus(s)
names := c.server.ListDir(path)
// Debug: Log what we found
c.logger.Info("Listed directory", "status", s, "path", path, "count", len(names))
var statusJobs []model.Job
for _, name := range names {
jobStatus, _ := c.taskQueue.GetJobStatus(name)
taskID := jobStatus["task_id"]
priority := int64(0)
if p, ok := jobStatus["priority"]; ok {
_, err := fmt.Sscanf(p, "%d", &priority)
if err != nil {
priority = 0
}
}
// Lazy loading: only fetch basic info for list view
// Full details (GPU, narrative) loaded on selection
statusJobs = append(statusJobs, model.Job{
Name: name,
Status: s,
TaskID: taskID,
Priority: priority,
Name: name,
Status: s,
// TaskID, Priority, GPU info loaded lazily
})
}
statusChan <- statusJobs
@ -85,6 +89,24 @@ func (c *Controller) loadJobs() tea.Cmd {
}
}
// loadJobDetails loads full details for a specific job (lazy loading).
// It is invoked when a job is selected so the list view can stay cheap;
// full details (priority, GPU, narrative) are fetched only on demand.
func (c *Controller) loadJobDetails(jobName string) tea.Cmd {
	return func() tea.Msg {
		jobStatus, _ := c.taskQueue.GetJobStatus(jobName)

		// Parse priority, resetting to 0 on malformed values — consistent
		// with how loadJobs handles the same field (the Sscanf error was
		// previously ignored).
		priority := int64(0)
		if p, ok := jobStatus["priority"]; ok {
			if _, err := fmt.Sscanf(p, "%d", &priority); err != nil {
				priority = 0
			}
		}
		// TODO(review): priority is parsed but not yet surfaced; wire it
		// into the detail view once model.Job carries lazily-loaded fields.
		_ = priority

		return model.StatusMsg{Text: "Loaded details for " + jobName, Level: "info"}
	}
}
func (c *Controller) loadQueue() tea.Cmd {
return func() tea.Msg {
tasks, err := c.taskQueue.GetQueuedTasks()
@ -362,6 +384,18 @@ func (c *Controller) showQueue(m model.State) tea.Cmd {
}
}
// loadDatasets returns a command that fetches the dataset list and turns
// it into a message for the update loop (or a status message on failure).
func (c *Controller) loadDatasets() tea.Cmd {
	return func() tea.Msg {
		list, listErr := c.taskQueue.ListDatasets()
		if listErr == nil {
			c.logger.Info("loaded datasets", "count", len(list))
			return model.DatasetsLoadedMsg(list)
		}
		c.logger.Error("failed to load datasets", "error", listErr)
		return model.StatusMsg{Text: "Failed to load datasets: " + listErr.Error(), Level: "error"}
	}
}
func tickCmd() tea.Cmd {
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
return model.TickMsg(t)

View file

@ -2,6 +2,7 @@ package controller
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/key"
@ -19,6 +20,7 @@ type Controller struct {
server *services.MLServer
taskQueue *services.TaskQueue
logger *logging.Logger
wsClient *services.WebSocketClient
}
func (c *Controller) handleKeyMsg(msg tea.KeyMsg, m model.State) (model.State, tea.Cmd) {
@ -143,6 +145,33 @@ func (c *Controller) handleGlobalKeys(msg tea.KeyMsg, m *model.State) []tea.Cmd
case key.Matches(msg, m.Keys.ViewExperiments):
m.ActiveView = model.ViewModeExperiments
cmds = append(cmds, c.loadExperiments())
case key.Matches(msg, m.Keys.ViewNarrative):
m.ActiveView = model.ViewModeNarrative
if job := getSelectedJob(*m); job != nil {
m.SelectedJob = *job
}
case key.Matches(msg, m.Keys.ViewTeam):
m.ActiveView = model.ViewModeTeam
case key.Matches(msg, m.Keys.ViewExperimentHistory):
m.ActiveView = model.ViewModeExperimentHistory
cmds = append(cmds, c.loadExperimentHistory())
case key.Matches(msg, m.Keys.ViewConfig):
m.ActiveView = model.ViewModeConfig
cmds = append(cmds, c.loadConfig())
case key.Matches(msg, m.Keys.ViewLogs):
m.ActiveView = model.ViewModeLogs
if job := getSelectedJob(*m); job != nil {
cmds = append(cmds, c.loadLogs(job.Name))
}
case key.Matches(msg, m.Keys.ViewExport):
if job := getSelectedJob(*m); job != nil {
cmds = append(cmds, c.exportJob(job.Name))
}
case key.Matches(msg, m.Keys.FilterTeam):
m.InputMode = true
m.Input.SetValue("@")
m.Input.Focus()
m.Status = "Filter by team member: @alice, @bob, @team-ml"
case key.Matches(msg, m.Keys.Cancel):
if job := getSelectedJob(*m); job != nil && job.TaskID != "" {
cmds = append(cmds, c.cancelTask(job.TaskID))
@ -181,8 +210,18 @@ func (c *Controller) applyWindowSize(msg tea.WindowSizeMsg, m model.State) model
m.QueueView.Height = listHeight - 4
m.SettingsView.Width = panelWidth
m.SettingsView.Height = listHeight - 4
m.NarrativeView.Width = panelWidth
m.NarrativeView.Height = listHeight - 4
m.TeamView.Width = panelWidth
m.TeamView.Height = listHeight - 4
m.ExperimentsView.Width = panelWidth
m.ExperimentsView.Height = listHeight - 4
m.ExperimentHistoryView.Width = panelWidth
m.ExperimentHistoryView.Height = listHeight - 4
m.ConfigView.Width = panelWidth
m.ConfigView.Height = listHeight - 4
m.LogsView.Width = panelWidth
m.LogsView.Height = listHeight - 4
return m
}
@ -245,7 +284,25 @@ func (c *Controller) handleStatusMsg(msg model.StatusMsg, m model.State) (model.
func (c *Controller) handleTickMsg(msg model.TickMsg, m model.State) (model.State, tea.Cmd) {
var cmds []tea.Cmd
if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
// Calculate actual refresh rate
now := time.Now()
if !m.LastFrameTime.IsZero() {
elapsed := now.Sub(m.LastFrameTime).Milliseconds()
if elapsed > 0 {
// Smooth the rate with simple averaging
m.RefreshRate = (m.RefreshRate*float64(m.FrameCount) + float64(elapsed)) / float64(m.FrameCount+1)
m.FrameCount++
if m.FrameCount > 100 {
m.FrameCount = 1
m.RefreshRate = float64(elapsed)
}
}
}
m.LastFrameTime = now
// 500ms refresh target for real-time updates
if time.Since(m.LastRefresh) > 500*time.Millisecond && !m.IsLoading {
m.LastRefresh = time.Now()
cmds = append(cmds, c.loadAllData())
}
@ -290,16 +347,25 @@ func New(
tq *services.TaskQueue,
logger *logging.Logger,
) *Controller {
// Create WebSocket client for real-time updates
wsClient := services.NewWebSocketClient(cfg.ServerURL, "", logger)
return &Controller{
config: cfg,
server: srv,
taskQueue: tq,
logger: logger,
wsClient: wsClient,
}
}
// Init initializes the TUI and returns initial commands
func (c *Controller) Init() tea.Cmd {
// Connect WebSocket for real-time updates
if err := c.wsClient.Connect(); err != nil {
c.logger.Error("WebSocket connection failed", "error", err)
}
return tea.Batch(
tea.SetWindowTitle("FetchML"),
c.loadAllData(),
@ -307,14 +373,17 @@ func (c *Controller) Init() tea.Cmd {
)
}
// Update handles all messages and updates the state
func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
switch typed := msg.(type) {
case tea.KeyMsg:
return c.handleKeyMsg(typed, m)
case tea.WindowSizeMsg:
updated := c.applyWindowSize(typed, m)
return c.finalizeUpdate(msg, updated)
// Only apply window size on first render, then keep constant
if m.Width == 0 && m.Height == 0 {
updated := c.applyWindowSize(typed, m)
return c.finalizeUpdate(msg, updated)
}
return c.finalizeUpdate(msg, m)
case model.JobsLoadedMsg:
return c.handleJobsLoadedMsg(typed, m)
case model.TasksLoadedMsg:
@ -323,8 +392,26 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
return c.handleGPUContent(typed, m)
case model.ContainerLoadedMsg:
return c.handleContainerContent(typed, m)
case model.QueueLoadedMsg:
return c.handleQueueContent(typed, m)
case model.DatasetsLoadedMsg:
// Format datasets into view content
var content strings.Builder
content.WriteString("Available Datasets\n")
content.WriteString(strings.Repeat("═", 50) + "\n\n")
if len(typed) == 0 {
content.WriteString("📭 No datasets found\n\n")
content.WriteString("Datasets will appear here when available\n")
content.WriteString("in the data directory.")
} else {
for i, ds := range typed {
content.WriteString(fmt.Sprintf("%d. 📁 %s\n", i+1, ds.Name))
content.WriteString(fmt.Sprintf(" Location: %s\n", ds.Location))
content.WriteString(fmt.Sprintf(" Size: %d bytes\n", ds.SizeBytes))
content.WriteString(fmt.Sprintf(" Last Access: %s\n\n", ds.LastAccess.Format("2006-01-02 15:04")))
}
}
m.DatasetView.SetContent(content.String())
m.DatasetView.GotoTop()
return c.finalizeUpdate(msg, m)
case model.SettingsContentMsg:
m.SettingsView.SetContent(string(typed))
return c.finalizeUpdate(msg, m)
@ -332,12 +419,36 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
m.ExperimentsView.SetContent(string(typed))
m.ExperimentsView.GotoTop()
return c.finalizeUpdate(msg, m)
case ExperimentHistoryLoadedMsg:
m.ExperimentHistoryView.SetContent(string(typed))
m.ExperimentHistoryView.GotoTop()
return c.finalizeUpdate(msg, m)
case ConfigLoadedMsg:
m.ConfigView.SetContent(string(typed))
m.ConfigView.GotoTop()
return c.finalizeUpdate(msg, m)
case LogsLoadedMsg:
m.LogsView.SetContent(string(typed))
m.LogsView.GotoTop()
return c.finalizeUpdate(msg, m)
case model.SettingsUpdateMsg:
return c.finalizeUpdate(msg, m)
case model.StatusMsg:
return c.handleStatusMsg(typed, m)
case model.TickMsg:
return c.handleTickMsg(typed, m)
case model.JobUpdateMsg:
// Handle real-time job status updates from WebSocket
m.Status = fmt.Sprintf("Job %s: %s", typed.JobName, typed.Status)
// Refresh job list to show updated status
return m, c.loadAllData()
case model.GPUUpdateMsg:
// Throttle GPU updates to 1/second (humans can't perceive faster)
if time.Since(m.LastGPUUpdate) > 1*time.Second {
m.LastGPUUpdate = time.Now()
return c.finalizeUpdate(msg, m)
}
return m, nil
default:
return c.finalizeUpdate(msg, m)
}
@ -346,6 +457,12 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
// ExperimentsLoadedMsg is sent when experiments are loaded.
// It carries the pre-formatted experiments view content as a string.
type ExperimentsLoadedMsg string

// ExperimentHistoryLoadedMsg is sent when experiment history is loaded.
// It carries the pre-formatted history view content as a string.
type ExperimentHistoryLoadedMsg string

// ConfigLoadedMsg is sent when config is loaded.
// It carries the pre-formatted config view content as a string.
type ConfigLoadedMsg string
func (c *Controller) loadExperiments() tea.Cmd {
return func() tea.Msg {
commitIDs, err := c.taskQueue.ListExperiments()
@ -372,3 +489,92 @@ func (c *Controller) loadExperiments() tea.Cmd {
return ExperimentsLoadedMsg(output)
}
}
// loadExperimentHistory returns a tea.Cmd that produces the experiment
// history view content. The body is a static placeholder until the
// history API endpoint exists.
func (c *Controller) loadExperimentHistory() tea.Cmd {
	return func() tea.Msg {
		// Placeholder - will show experiment history with annotations
		const placeholder = "Experiment History & Annotations\n\n" +
			"This view will show:\n" +
			"- Previous experiment runs\n" +
			"- Annotations and notes\n" +
			"- Config snapshots\n" +
			"- Side-by-side comparisons\n\n" +
			"(Requires API: GET /api/experiments/:id/history)"
		return ExperimentHistoryLoadedMsg(placeholder)
	}
}
// loadConfig returns a tea.Cmd that renders the current configuration as
// read-only text: a boxed summary of settings that deviate from their
// defaults, followed by a full configuration dump and a hint pointing at
// the CLI for modifications.
func (c *Controller) loadConfig() tea.Cmd {
	return func() tea.Msg {
		var b strings.Builder
		b.WriteString("⚙️ Config View (Read-Only)\n\n")
		b.WriteString("┌─ Changes from Defaults ─────────────────────┐\n")

		// Collect only the settings that differ from their defaults.
		var overrides []string
		addIf := func(changed bool, line string) {
			if changed {
				overrides = append(overrides, line)
			}
		}
		addIf(c.config.Host != "", fmt.Sprintf("│ Host: %s", c.config.Host))
		addIf(c.config.Port != 0 && c.config.Port != 22, fmt.Sprintf("│ Port: %d (default: 22)", c.config.Port))
		addIf(c.config.BasePath != "", fmt.Sprintf("│ Base Path: %s", c.config.BasePath))
		addIf(c.config.RedisAddr != "" && c.config.RedisAddr != "localhost:6379", fmt.Sprintf("│ Redis: %s (default: localhost:6379)", c.config.RedisAddr))
		addIf(c.config.ServerURL != "", fmt.Sprintf("│ Server: %s", c.config.ServerURL))

		if len(overrides) == 0 {
			b.WriteString("│ (Using all default settings)\n")
		} else {
			for _, line := range overrides {
				b.WriteString(line + "\n")
			}
		}
		b.WriteString("└─────────────────────────────────────────────┘\n\n")

		b.WriteString("Full Configuration:\n")
		b.WriteString(fmt.Sprintf(" Host: %s\n", c.config.Host))
		b.WriteString(fmt.Sprintf(" Port: %d\n", c.config.Port))
		b.WriteString(fmt.Sprintf(" Base Path: %s\n", c.config.BasePath))
		b.WriteString(fmt.Sprintf(" Redis: %s\n", c.config.RedisAddr))
		b.WriteString(fmt.Sprintf(" Server: %s\n", c.config.ServerURL))
		b.WriteString(fmt.Sprintf(" User: %s\n\n", c.config.User))
		b.WriteString("Use CLI to modify: ml config set <key> <value>")

		return ConfigLoadedMsg(b.String())
	}
}
// LogsLoadedMsg is sent when logs are loaded.
// It carries the pre-formatted log view content as a string.
type LogsLoadedMsg string
// loadLogs returns a tea.Cmd that produces placeholder log content for
// the given job; real streaming waits on the logs API endpoint.
func (c *Controller) loadLogs(jobName string) tea.Cmd {
	return func() tea.Msg {
		// Placeholder - will stream logs from job
		body := fmt.Sprintf(
			"📜 Logs for %s\n\nLog streaming will appear here...\n\n(Requires API: GET /api/jobs/%s/logs?follow=true)",
			jobName, jobName,
		)
		return LogsLoadedMsg(body)
	}
}
// ExportCompletedMsg is sent when export is complete.
type ExportCompletedMsg struct {
	JobName string // job that was exported
	Path    string // path to the exported file
}
// exportJob returns a tea.Cmd that reports an export as started for the
// given job; the actual export work is not yet wired up here.
func (c *Controller) exportJob(jobName string) tea.Cmd {
	return func() tea.Msg {
		// Show export in progress
		status := model.StatusMsg{
			Level: "info",
			Text:  "Exporting " + jobName + "... (anonymized)",
		}
		return status
	}
}

View file

@ -1,7 +1,11 @@
// Package model provides TUI data structures and state management
package model
import "fmt"
import (
"fmt"
"github.com/charmbracelet/lipgloss"
)
// JobStatus represents the status of a job
type JobStatus string
@ -21,12 +25,23 @@ type Job struct {
Status JobStatus
TaskID string
Priority int64
// Narrative fields for research context
Hypothesis string
Context string
Intent string
ExpectedOutcome string
ActualOutcome string
OutcomeStatus string // validated, invalidated, inconclusive
// GPU allocation tracking
GPUDeviceID int // -1 if not assigned
GPUUtilization int // 0-100%
GPUMemoryUsed int64 // MB
}
// Title returns the job name, used as the display title in the job list.
func (j Job) Title() string { return j.Name }
// Description returns a formatted description with status icon
// Description returns a formatted description with status icon and GPU info
func (j Job) Description() string {
icon := map[JobStatus]string{
StatusPending: "⏸",
@ -39,7 +54,16 @@ func (j Job) Description() string {
if j.Priority > 0 {
pri = fmt.Sprintf(" [P%d]", j.Priority)
}
return fmt.Sprintf("%s %s%s", icon, j.Status, pri)
gpu := ""
if j.GPUDeviceID >= 0 {
gpu = fmt.Sprintf(" [GPU:%d %d%%]", j.GPUDeviceID, j.GPUUtilization)
}
// Apply status color to the status text
statusStyle := lipgloss.NewStyle().Foreground(StatusColor(j.Status))
coloredStatus := statusStyle.Render(string(j.Status))
return fmt.Sprintf("%s %s%s%s", icon, coloredStatus, pri, gpu)
}
// FilterValue returns the value used for filtering

View file

@ -5,42 +5,56 @@ import "github.com/charmbracelet/bubbles/key"
// KeyMap defines key bindings for the TUI.
//
// NOTE(review): the rendered diff merged the old and new field lists,
// leaving duplicate field declarations; this is the deduplicated new set.
type KeyMap struct {
	Refresh               key.Binding
	Trigger               key.Binding
	TriggerArgs           key.Binding
	ViewQueue             key.Binding
	ViewContainer         key.Binding
	ViewGPU               key.Binding
	ViewJobs              key.Binding
	ViewDatasets          key.Binding
	ViewExperiments       key.Binding
	ViewSettings          key.Binding
	ViewNarrative         key.Binding
	ViewTeam              key.Binding
	ViewExperimentHistory key.Binding
	ViewConfig            key.Binding
	ViewLogs              key.Binding
	ViewExport            key.Binding
	FilterTeam            key.Binding
	Cancel                key.Binding
	Delete                key.Binding
	MarkFailed            key.Binding
	RefreshGPU            key.Binding
	Help                  key.Binding
	Quit                  key.Binding
}
// DefaultKeys returns the default key bindings for the TUI.
//
// Deliberate reassignments relative to the previous scheme: "q" opens the
// queue view (quit is Ctrl+C only), "x" cancels a task, and "c" opens the
// config view.
//
// NOTE(review): the rendered diff merged the old and new binding sets,
// producing duplicate composite-literal keys; this is the deduplicated
// new set.
func DefaultKeys() KeyMap {
	return KeyMap{
		Refresh:               key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
		Trigger:               key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
		TriggerArgs:           key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
		ViewQueue:             key.NewBinding(key.WithKeys("q"), key.WithHelp("q", "view queue")),
		ViewContainer:         key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
		ViewGPU:               key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
		ViewJobs:              key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
		ViewDatasets:          key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
		ViewExperiments:       key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
		ViewNarrative:         key.NewBinding(key.WithKeys("n"), key.WithHelp("n", "narrative")),
		ViewTeam:              key.NewBinding(key.WithKeys("m"), key.WithHelp("m", "team")),
		ViewExperimentHistory: key.NewBinding(key.WithKeys("e"), key.WithHelp("e", "experiment history")),
		ViewConfig:            key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "config")),
		ViewSettings:          key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
		ViewLogs:              key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "logs")),
		ViewExport:            key.NewBinding(key.WithKeys("E"), key.WithHelp("E", "export job")),
		FilterTeam:            key.NewBinding(key.WithKeys("@"), key.WithHelp("@", "filter by team")),
		Cancel:                key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "cancel task")),
		Delete:                key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
		MarkFailed:            key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
		RefreshGPU:            key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
		Help:                  key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
		Quit:                  key.NewBinding(key.WithKeys("ctrl+c"), key.WithHelp("ctrl+c", "quit")),
	}
}

View file

@ -9,6 +9,9 @@ type JobsLoadedMsg []Job
// TasksLoadedMsg contains loaded tasks from the queue.
type TasksLoadedMsg []*Task

// DatasetsLoadedMsg contains loaded datasets.
type DatasetsLoadedMsg []DatasetInfo

// GpuLoadedMsg contains GPU status information as a string.
type GpuLoadedMsg string

View file

@ -32,13 +32,18 @@ type ViewMode int
// ViewMode constants represent different TUI views
const (
ViewModeJobs ViewMode = iota // Jobs view mode
ViewModeGPU // GPU status view mode
ViewModeQueue // Queue status view mode
ViewModeContainer // Container status view mode
ViewModeSettings // Settings view mode
ViewModeDatasets // Datasets view mode
ViewModeExperiments // Experiments view mode
ViewModeJobs ViewMode = iota // Jobs view mode
ViewModeGPU // GPU status view mode
ViewModeQueue // Queue status view mode
ViewModeContainer // Container status view mode
ViewModeSettings // Settings view mode
ViewModeDatasets // Datasets view mode
ViewModeExperiments // Experiments view mode
ViewModeNarrative // Narrative/Outcome view mode
ViewModeTeam // Team collaboration view mode
ViewModeExperimentHistory // Experiment history view mode
ViewModeConfig // Config view mode
ViewModeLogs // Logs streaming view mode
)
// DatasetInfo represents dataset information in the TUI
@ -51,32 +56,42 @@ type DatasetInfo struct {
// State holds the application state.
//
// NOTE(review): the rendered diff merged the old and new field lists,
// leaving duplicate declarations; this is the deduplicated new set in the
// new version's field order.
type State struct {
	Jobs        []Job
	QueuedTasks []*Task
	Datasets    []DatasetInfo

	// One widget per view.
	JobList               list.Model
	GpuView               viewport.Model
	ContainerView         viewport.Model
	QueueView             viewport.Model
	SettingsView          viewport.Model
	DatasetView           viewport.Model
	ExperimentsView       viewport.Model
	NarrativeView         viewport.Model
	TeamView              viewport.Model
	ExperimentHistoryView viewport.Model
	ConfigView            viewport.Model
	LogsView              viewport.Model

	// SelectedJob is the job whose details (e.g. narrative) are shown.
	SelectedJob Job

	Input       textinput.Model
	APIKeyInput textinput.Model

	Status     string
	ErrorMsg   string
	InputMode  bool
	Width      int
	Height     int
	ShowHelp   bool
	Spinner    spinner.Model
	ActiveView ViewMode

	// Refresh/frame bookkeeping.
	LastRefresh   time.Time
	LastFrameTime time.Time
	RefreshRate   float64 // measured in ms
	FrameCount    int
	LastGPUUpdate time.Time // last applied GPU update; used for throttling

	IsLoading     bool
	JobStats      map[JobStatus]int
	APIKey        string
	SettingsIndex int
	Keys          KeyMap
}
// InitialState creates the initial application state
@ -105,25 +120,54 @@ func InitialState(apiKey string) State {
s.Style = SpinnerStyle()
return State{
JobList: jobList,
GpuView: viewport.New(0, 0),
ContainerView: viewport.New(0, 0),
QueueView: viewport.New(0, 0),
SettingsView: viewport.New(0, 0),
DatasetView: viewport.New(0, 0),
ExperimentsView: viewport.New(0, 0),
Input: input,
APIKeyInput: apiKeyInput,
Status: "Connected",
InputMode: false,
ShowHelp: false,
Spinner: s,
ActiveView: ViewModeJobs,
LastRefresh: time.Now(),
IsLoading: false,
JobStats: make(map[JobStatus]int),
APIKey: apiKey,
SettingsIndex: 0,
Keys: DefaultKeys(),
JobList: jobList,
GpuView: viewport.New(0, 0),
ContainerView: viewport.New(0, 0),
QueueView: viewport.New(0, 0),
SettingsView: viewport.New(0, 0),
DatasetView: viewport.New(0, 0),
ExperimentsView: viewport.New(0, 0),
NarrativeView: viewport.New(0, 0),
TeamView: viewport.New(0, 0),
ExperimentHistoryView: viewport.New(0, 0),
ConfigView: viewport.New(0, 0),
LogsView: viewport.New(0, 0),
Input: input,
APIKeyInput: apiKeyInput,
Status: "Connected",
InputMode: false,
ShowHelp: false,
Spinner: s,
ActiveView: ViewModeJobs,
LastRefresh: time.Now(),
IsLoading: false,
JobStats: make(map[JobStatus]int),
APIKey: apiKey,
SettingsIndex: 0,
Keys: DefaultKeys(),
}
}
// LogMsg represents a log line from a job.
type LogMsg struct {
	JobName string `json:"job_name"` // job the line belongs to
	Line    string `json:"line"`     // raw log line text
	Time    string `json:"time"`     // timestamp string as sent by the server (format not enforced here)
}

// JobUpdateMsg represents a real-time job status update via WebSocket.
type JobUpdateMsg struct {
	JobName  string `json:"job_name"`
	Status   string `json:"status"`
	TaskID   string `json:"task_id"`
	Progress int    `json:"progress"` // NOTE(review): presumably 0-100 percent — confirm with server protocol
}

// GPUUpdateMsg represents a real-time GPU status update via WebSocket.
type GPUUpdateMsg struct {
	DeviceID    int   `json:"device_id"`
	Utilization int   `json:"utilization"`  // NOTE(review): presumably percent 0-100 — confirm
	MemoryUsed  int64 `json:"memory_used"`  // NOTE(review): presumably MB — confirm with server protocol
	MemoryTotal int64 `json:"memory_total"`
	Temperature int   `json:"temperature"`
}

View file

@ -6,6 +6,38 @@ import (
"github.com/charmbracelet/lipgloss"
)
// Status colors for job list items.
// StatusColor maps job statuses onto these values.
var (
	// StatusRunningColor is green for running jobs
	StatusRunningColor = lipgloss.Color("#2ecc71")
	// StatusPendingColor is yellow for pending jobs
	StatusPendingColor = lipgloss.Color("#f1c40f")
	// StatusFailedColor is red for failed jobs
	StatusFailedColor = lipgloss.Color("#e74c3c")
	// StatusFinishedColor is blue for completed jobs
	StatusFinishedColor = lipgloss.Color("#3498db")
	// StatusQueuedColor is gray for queued jobs
	StatusQueuedColor = lipgloss.Color("#95a5a6")
)
// StatusColor returns the display color for a job status.
// Unknown statuses fall back to white.
func StatusColor(status JobStatus) lipgloss.Color {
	palette := map[JobStatus]lipgloss.Color{
		StatusRunning:  StatusRunningColor,
		StatusPending:  StatusPendingColor,
		StatusFailed:   StatusFailedColor,
		StatusFinished: StatusFinishedColor,
		StatusQueued:   StatusQueuedColor,
	}
	if color, ok := palette[status]; ok {
		return color
	}
	return lipgloss.Color("#ffffff")
}
// NewJobListDelegate creates a styled delegate for the job list
func NewJobListDelegate() list.DefaultDelegate {
delegate := list.NewDefaultDelegate()

View file

@ -0,0 +1,46 @@
// Package services provides TUI service clients
package services
import (
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
// ExportService handles job export functionality for TUI.
type ExportService struct {
	serverURL string          // base URL of the API server
	apiKey    string          // API key for the server (not yet used by the placeholder ExportJob)
	logger    *logging.Logger // structured logger for export events
}
// NewExportService creates a new export service that talks to serverURL,
// authenticates with apiKey, and logs through logger.
func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportService {
	svc := &ExportService{
		serverURL: serverURL,
		apiKey:    apiKey,
		logger:    logger,
	}
	return svc
}
// ExportJob exports a job with optional anonymization.
// Returns the path to the exported file.
//
// NOTE(review): placeholder — the real implementation would call
// POST /api/jobs/{id}/export?anonymize=true and stream the archive.
func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) {
	s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize)

	// Use the platform temp directory instead of a hardcoded "/tmp" so the
	// path is valid on every OS. The timestamp keeps names distinct; a real
	// implementation should use os.CreateTemp for collision-safe names.
	name := fmt.Sprintf("%s_export_%d.tar.gz", jobName, time.Now().Unix())
	exportPath := filepath.Join(os.TempDir(), name)

	s.logger.Info("export complete", "job", jobName, "path", exportPath)
	return exportPath, nil
}
// ExportOptions contains options for export.
// NOTE(review): not yet consumed by ExportJob — presumably intended for a
// future variant that accepts per-export settings; confirm before relying
// on these fields having any effect.
type ExportOptions struct {
	Anonymize   bool
	IncludeLogs bool
	IncludeData bool
}

View file

@ -4,8 +4,12 @@ package services
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/domain"
"github.com/jfraeys/fetch_ml/internal/experiment"
"github.com/jfraeys/fetch_ml/internal/network"
@ -21,6 +25,7 @@ type TaskQueue struct {
*queue.TaskQueue // Embed to inherit all queue methods directly
expManager *experiment.Manager
ctx context.Context
config *config.Config
}
// NewTaskQueue creates a new task queue service
@ -37,14 +42,17 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
return nil, fmt.Errorf("failed to create task queue: %w", err)
}
// Initialize experiment manager
// TODO: Get base path from config
expManager := experiment.NewManager("./experiments")
// Initialize experiment manager with proper path
// BasePath already includes the mode-based experiments path (e.g., ./data/dev/experiments)
expDir := cfg.BasePath
os.MkdirAll(expDir, 0755)
expManager := experiment.NewManager(expDir)
return &TaskQueue{
TaskQueue: internalQueue,
expManager: expManager,
ctx: context.Background(),
config: cfg,
}, nil
}
@ -94,20 +102,38 @@ func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
return map[string]string{}, nil
}
// ListDatasets retrieves available datasets (TUI-specific: currently returns empty)
func (tq *TaskQueue) ListDatasets() ([]struct {
Name string
SizeBytes int64
Location string
LastAccess string
}, error) {
// This method doesn't exist in internal queue, return empty for now
return []struct {
Name string
SizeBytes int64
Location string
LastAccess string
}{}, nil
// ListDatasets retrieves available datasets from the filesystem
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
var datasets []model.DatasetInfo
// Scan the active data directory for datasets
dataDir := tq.config.BasePath
if dataDir == "" {
return datasets, nil
}
entries, err := os.ReadDir(dataDir)
if err != nil {
// Directory might not exist yet, return empty
return datasets, nil
}
for _, entry := range entries {
if entry.IsDir() {
info, err := entry.Info()
if err != nil {
continue
}
datasets = append(datasets, model.DatasetInfo{
Name: entry.Name(),
SizeBytes: info.Size(),
Location: filepath.Join(dataDir, entry.Name()),
LastAccess: time.Now(),
})
}
}
return datasets, nil
}
// ListExperiments retrieves experiment list

View file

@ -0,0 +1,275 @@
// Package services provides TUI service clients
package services
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
"github.com/jfraeys/fetch_ml/internal/logging"
)
// WebSocketClient manages real-time updates from the server.
type WebSocketClient struct {
	conn      *websocket.Conn
	serverURL string // http(s) URL; converted to ws(s) in Connect
	apiKey    string // sent as the X-API-Key header during the handshake
	logger    *logging.Logger

	// Channels for different update types (buffered, capacity 100)
	jobUpdates    chan model.JobUpdateMsg
	gpuUpdates    chan model.GPUUpdateMsg
	statusUpdates chan model.StatusMsg

	// Control
	ctx    context.Context
	cancel context.CancelFunc
	// connected is read and written from several goroutines without
	// synchronization — NOTE(review): data race; consider atomic.Bool.
	connected bool
}
// JobUpdateMsg represents a real-time job status update.
// It mirrors model.JobUpdateMsg field-for-field so handleBinaryMessage can
// forward values with a plain type conversion.
type JobUpdateMsg struct {
	JobName  string `json:"job_name"`
	Status   string `json:"status"`
	TaskID   string `json:"task_id"`
	Progress int    `json:"progress"`
}

// GPUUpdateMsg represents a real-time GPU status update.
// Mirrors model.GPUUpdateMsg (see handleBinaryMessage's conversion).
type GPUUpdateMsg struct {
	DeviceID    int   `json:"device_id"`
	Utilization int   `json:"utilization"`
	MemoryUsed  int64 `json:"memory_used"`
	MemoryTotal int64 `json:"memory_total"`
	Temperature int   `json:"temperature"`
}
// NewWebSocketClient creates a new WebSocket client. The client does not
// connect until Connect is called; the update channels are buffered so a
// briefly slow consumer does not immediately stall the reader.
func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient {
	ctx, cancel := context.WithCancel(context.Background())

	client := &WebSocketClient{
		serverURL: serverURL,
		apiKey:    apiKey,
		logger:    logger,
		ctx:       ctx,
		cancel:    cancel,
	}
	client.jobUpdates = make(chan model.JobUpdateMsg, 100)
	client.gpuUpdates = make(chan model.GPUUpdateMsg, 100)
	client.statusUpdates = make(chan model.StatusMsg, 100)
	return client
}
// Connect establishes the WebSocket connection, derives the ws(s) URL from
// the configured server URL, authenticates via the X-API-Key header, and
// starts the reader and heartbeat goroutines on success.
func (c *WebSocketClient) Connect() error {
	// Parse server URL and construct WebSocket URL
	u, err := url.Parse(c.serverURL)
	if err != nil {
		return fmt.Errorf("invalid server URL: %w", err)
	}

	// Convert http/https to ws/wss
	wsScheme := "ws"
	if u.Scheme == "https" {
		wsScheme = "wss"
	}
	wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)

	// Create dialer with timeout
	dialer := websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
		Subprotocols:     []string{"fetchml-v1"},
	}

	// Add API key to headers
	headers := http.Header{}
	if c.apiKey != "" {
		headers.Set("X-API-Key", c.apiKey)
	}

	conn, resp, err := dialer.Dial(wsURL, headers)
	if err != nil {
		if resp != nil {
			// Per gorilla/websocket docs, the caller must close the
			// response body on a failed handshake to avoid leaking the
			// underlying connection.
			resp.Body.Close()
			return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err)
		}
		return fmt.Errorf("websocket dial failed: %w", err)
	}

	c.conn = conn
	c.connected = true
	c.logger.Info("websocket connected", "url", wsURL)

	// Start message handler
	go c.messageHandler()

	// Start heartbeat
	go c.heartbeat()

	return nil
}
// Disconnect cancels the client's background goroutines and closes the
// underlying connection, if one was ever established.
func (c *WebSocketClient) Disconnect() {
	c.cancel()
	c.connected = false
	if conn := c.conn; conn != nil {
		conn.Close()
	}
}
// IsConnected returns true if connected.
// NOTE(review): reads the connected flag without synchronization while
// other goroutines write it — racy under the race detector.
func (c *WebSocketClient) IsConnected() bool {
	return c.connected
}
// messageHandler reads messages from the WebSocket until the client
// context is cancelled, dispatching binary frames to handleBinaryMessage
// and text frames to handleTextMessage. On a read error it attempts one
// reconnect; a successful Connect starts a fresh handler/heartbeat pair,
// so this goroutine then exits instead of leaving two readers running
// (the previous version kept looping after Connect, duplicating
// goroutines on every reconnect).
func (c *WebSocketClient) messageHandler() {
	for {
		select {
		case <-c.ctx.Done():
			return
		default:
		}

		if c.conn == nil {
			time.Sleep(100 * time.Millisecond)
			continue
		}

		// Set read deadline so a silent peer cannot block forever.
		if err := c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)); err != nil {
			c.logger.Error("websocket set read deadline failed", "error", err)
		}

		// Read message
		messageType, data, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.logger.Error("websocket read error", "error", err)
			}
			c.connected = false

			// Back off before reconnecting, but abort promptly on shutdown
			// (the old unconditional Sleep ignored cancellation).
			select {
			case <-c.ctx.Done():
				return
			case <-time.After(5 * time.Second):
			}

			if err := c.Connect(); err != nil {
				c.logger.Error("websocket reconnect failed", "error", err)
				continue
			}
			// Connect spawned a new messageHandler/heartbeat; retire this one.
			return
		}

		// Handle binary vs text messages
		if messageType == websocket.BinaryMessage {
			c.handleBinaryMessage(data)
		} else {
			c.handleTextMessage(data)
		}
	}
}
// handleBinaryMessage decodes a binary frame laid out as
// [opcode:1][json payload...] and forwards the decoded update onto the
// matching channel. Frames shorter than two bytes and unknown opcodes are
// silently ignored.
func (c *WebSocketClient) handleBinaryMessage(data []byte) {
	const (
		opJobUpdate = 0x01
		opGPUUpdate = 0x02
		opStatus    = 0x03
	)

	if len(data) < 2 {
		return
	}

	// Binary protocol: [opcode:1][data...]
	opcode, payload := data[0], data[1:]

	switch opcode {
	case opJobUpdate:
		var update JobUpdateMsg
		if err := json.Unmarshal(payload, &update); err != nil {
			c.logger.Error("failed to unmarshal job update", "error", err)
			return
		}
		c.jobUpdates <- model.JobUpdateMsg(update)

	case opGPUUpdate:
		var update GPUUpdateMsg
		if err := json.Unmarshal(payload, &update); err != nil {
			c.logger.Error("failed to unmarshal GPU update", "error", err)
			return
		}
		c.gpuUpdates <- model.GPUUpdateMsg(update)

	case opStatus:
		var status model.StatusMsg
		if err := json.Unmarshal(payload, &status); err != nil {
			c.logger.Error("failed to unmarshal status", "error", err)
			return
		}
		c.statusUpdates <- status
	}
}
// handleTextMessage handles text WebSocket messages (JSON).
// All branches are currently stubs: the message is parsed and its "type"
// field inspected, then the payload is discarded.
func (c *WebSocketClient) handleTextMessage(data []byte) {
	var msg map[string]interface{}
	if err := json.Unmarshal(data, &msg); err != nil {
		c.logger.Error("failed to unmarshal text message", "error", err)
		return
	}

	msgType, _ := msg["type"].(string)
	switch msgType {
	case "job_update":
		// Handle JSON job updates (not implemented yet)
	case "gpu_update":
		// Handle JSON GPU updates (not implemented yet)
	case "status":
		// Handle status messages (not implemented yet)
	}
}
// heartbeat sends a ping every 30 seconds until the client context is
// cancelled; a failed ping only flags the client as disconnected
// (reconnection is driven by messageHandler's read loop).
//
// NOTE(review): this WriteMessage can run concurrently with writes from
// Subscribe; gorilla/websocket allows at most one concurrent writer —
// confirm callers serialize writes or add a write mutex.
func (c *WebSocketClient) heartbeat() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
			if c.conn != nil {
				if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
					c.logger.Error("websocket ping failed", "error", err)
					c.connected = false
				}
			}
		}
	}
}
// Subscribe asks the server to start pushing updates for the named
// channels. Returns an error when the client is not connected, when the
// request cannot be encoded, or when the write fails.
func (c *WebSocketClient) Subscribe(channels ...string) error {
	if !c.connected {
		return fmt.Errorf("not connected")
	}

	subMsg := map[string]interface{}{
		"action":   "subscribe",
		"channels": channels,
	}
	// Surface the (unlikely) encoding failure instead of silently sending
	// a nil payload as the previous `data, _ :=` did.
	data, err := json.Marshal(subMsg)
	if err != nil {
		return fmt.Errorf("marshal subscribe message: %w", err)
	}
	return c.conn.WriteMessage(websocket.TextMessage, data)
}
// GetJobUpdates returns the job updates channel (receive-only).
func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg {
	return c.jobUpdates
}

// GetGPUUpdates returns the GPU updates channel (receive-only).
func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg {
	return c.gpuUpdates
}

// GetStatusUpdates returns the status updates channel (receive-only).
func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg {
	return c.statusUpdates
}

View file

@ -0,0 +1,147 @@
// Package view provides TUI rendering functionality
package view
import (
"strings"
"github.com/charmbracelet/lipgloss"
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
// NarrativeView displays job research context and outcome.
type NarrativeView struct {
	Width  int // available render width; used for wrapping and the border box
	Height int // available render height; not referenced by the render methods in this file
}
// View renders the narrative/outcome for a job: hypothesis, context,
// intent, expected outcome, and — when recorded — the actual outcome
// prefixed with an icon derived from OutcomeStatus. Jobs with none of
// hypothesis/context/intent set get a placeholder panel instead.
func (v *NarrativeView) View(job model.Job) string {
	if job.Hypothesis == "" && job.Context == "" && job.Intent == "" {
		return v.renderEmpty()
	}

	var sections []string

	// Title
	titleStyle := lipgloss.NewStyle().
		Bold(true).
		Foreground(lipgloss.Color("#ff9e64")).
		MarginBottom(1)
	sections = append(sections, titleStyle.Render("📊 Research Context"))

	// Hypothesis
	if job.Hypothesis != "" {
		sections = append(sections, v.renderSection("🧪 Hypothesis", job.Hypothesis))
	}

	// Context
	if job.Context != "" {
		sections = append(sections, v.renderSection("📚 Context", job.Context))
	}

	// Intent
	if job.Intent != "" {
		sections = append(sections, v.renderSection("🎯 Intent", job.Intent))
	}

	// Expected Outcome
	if job.ExpectedOutcome != "" {
		sections = append(sections, v.renderSection("📈 Expected Outcome", job.ExpectedOutcome))
	}

	// Actual Outcome (if available)
	if job.ActualOutcome != "" {
		statusIcon := "❓" // fallback for unknown/unset outcome status
		switch job.OutcomeStatus {
		case "validated":
			statusIcon = "✅"
		case "invalidated":
			statusIcon = "❌"
		case "inconclusive":
			statusIcon = "⚖️"
		case "partial":
			statusIcon = "📝"
		}
		sections = append(sections, v.renderSection(statusIcon+" Actual Outcome", job.ActualOutcome))
	}

	// Combine all sections
	content := lipgloss.JoinVertical(lipgloss.Left, sections...)

	// Apply border and padding; Width-4 leaves room for the border and
	// horizontal padding within the available width.
	style := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("#7aa2f7")).
		Padding(1, 2).
		Width(v.Width - 4)

	return style.Render(content)
}
// renderSection renders one titled narrative section: a bold blue title
// followed by the indented, width-wrapped content and a blank spacer line.
func (v *NarrativeView) renderSection(title, content string) string {
	heading := lipgloss.NewStyle().
		Bold(true).
		Foreground(lipgloss.Color("#7aa2f7")).
		Render(title)

	body := lipgloss.NewStyle().
		Foreground(lipgloss.Color("#d8dee9")).
		MarginLeft(2).
		Render(wrapText(content, v.Width-12))

	return lipgloss.JoinVertical(lipgloss.Left, heading, body, "")
}
// renderEmpty renders the placeholder shown when a job carries no
// narrative fields, with a hint on how to supply them.
func (v *NarrativeView) renderEmpty() string {
	message := "No narrative data available for this job.\n\n" +
		"Use --hypothesis, --context, --intent flags when queuing jobs."
	return lipgloss.NewStyle().
		Foreground(lipgloss.Color("#88c0d0")).
		Italic(true).
		Padding(2).
		Render(message)
}
// wrapText wraps text to maxWidth columns, preserving existing newlines.
// Width is measured in runes, not bytes, so multi-byte UTF-8 (the emoji
// and box-drawing characters this TUI uses heavily) no longer inflates a
// line's apparent length and triggers premature wrapping. Words longer
// than maxWidth are kept on their own line unbroken, and runs of
// whitespace inside wrapped lines collapse to single spaces (both as
// before). maxWidth <= 0 disables wrapping.
func wrapText(text string, maxWidth int) string {
	if maxWidth <= 0 {
		return text
	}

	lines := strings.Split(text, "\n")
	var result []string

	for _, line := range lines {
		if len([]rune(line)) <= maxWidth {
			result = append(result, line)
			continue
		}

		// Simple word wrap, counting runes.
		words := strings.Fields(line)
		var current []rune
		for _, word := range words {
			w := []rune(word)
			if len(current)+len(w)+1 > maxWidth {
				if len(current) > 0 {
					result = append(result, string(current))
				}
				current = w
			} else {
				if len(current) > 0 {
					current = append(current, ' ')
				}
				current = append(current, w...)
			}
		}
		if len(current) > 0 {
			result = append(result, string(current))
		}
	}

	return strings.Join(result, "\n")
}

View file

@ -2,6 +2,7 @@
package view
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
@ -180,6 +181,27 @@ func getRightPanel(m model.State, width int) string {
style = activeBorderStyle
viewTitle = "📦 Datasets"
content = m.DatasetView.View()
case model.ViewModeNarrative:
style = activeBorderStyle
viewTitle = "📊 Research Context"
narrativeView := &NarrativeView{Width: width, Height: m.Height}
content = narrativeView.View(m.SelectedJob)
case model.ViewModeTeam:
style = activeBorderStyle
viewTitle = "👥 Team Jobs"
content = "Team collaboration view - shows jobs from all team members\n\n(Requires API: GET /api/jobs?all_users=true)"
case model.ViewModeExperimentHistory:
style = activeBorderStyle
viewTitle = "📜 Experiment History"
content = m.ExperimentHistoryView.View()
case model.ViewModeConfig:
style = activeBorderStyle
viewTitle = "⚙️ Config"
content = m.ConfigView.View()
case model.ViewModeLogs:
style = activeBorderStyle
viewTitle = "📜 Logs"
content = m.LogsView.View()
default:
viewTitle = "📊 System Overview"
content = getOverviewPanel(m)
@ -218,7 +240,14 @@ func getStatusBar(m model.State) string {
if m.ShowHelp {
statusText = "Press 'h' to hide help"
}
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText)
// Add refresh rate indicator
refreshInfo := ""
if m.RefreshRate > 0 {
refreshInfo = fmt.Sprintf(" | %.0fms", m.RefreshRate)
}
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText + refreshInfo)
}
func helpText(m model.State) string {
@ -242,25 +271,28 @@ func helpText(m model.State) string {
Navigation
j/k, / : Move selection / : Filter jobs
1 : Job list view 2 : Datasets view
3 : Experiments view v : Queue view
3 : Experiments view q : Queue view
g : GPU view o : Container view
s : Settings view
s : Settings view n : Narrative view
m : Team view e : Experiment history
c : Config view l : Logs (job stream)
E : Export job @ : Filter by team
Actions
t : Queue job a : Queue w/ args
c : Cancel task d : Delete pending
x : Cancel task d : Delete pending
f : Mark as failed r : Refresh all
G : Refresh GPU only
General
h or ? : Toggle this help q/Ctrl+C : Quit
h or ? : Toggle this help Ctrl+C : Quit
`
}
// getQuickHelp returns the one-line key hint bar for the current view;
// the settings view gets its own dedicated hints.
//
// NOTE(review): the rendered diff left both the stale and the updated
// return statements in each branch (the stale ones made the new text
// unreachable); this keeps only the updated bindings (q=queue,
// Ctrl+C=quit).
func getQuickHelp(m model.State) string {
	if m.ActiveView == model.ViewModeSettings {
		return " ↑/↓:move enter:select esc:exit settings Ctrl+C:quit"
	}
	return " h:help 1:jobs 2:datasets 3:experiments q:queue g:gpu o:containers " +
		"n:narrative m:team e:history c:config l:logs E:export @:filter s:settings t:trigger r:refresh Ctrl+C:quit"
}

View file

@ -5,6 +5,7 @@ import (
"log"
"os"
"os/signal"
"path/filepath"
"syscall"
tea "github.com/charmbracelet/bubbletea"
@ -41,6 +42,16 @@ func (m AppModel) View() string {
}
func main() {
// Redirect logs to file to prevent TUI disruption
homeDir, _ := os.UserHomeDir()
logDir := filepath.Join(homeDir, ".ml", "logs")
os.MkdirAll(logDir, 0755)
logFile, logErr := os.OpenFile(filepath.Join(logDir, "tui.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if logErr == nil {
log.SetOutput(logFile)
defer logFile.Close()
}
// Parse authentication flags
authFlags := auth.ParseAuthFlags()
if err := auth.ValidateFlags(authFlags); err != nil {
@ -85,10 +96,10 @@ func main() {
log.Printf(" 4. Run TUI: ./bin/tui")
log.Printf("")
log.Printf("Example ~/.ml/config.toml:")
log.Printf(" worker_host = \"localhost\"")
log.Printf(" worker_user = \"your_username\"")
log.Printf(" worker_base = \"~/ml_jobs\"")
log.Printf(" worker_port = 22")
log.Printf(" mode = \"dev\"")
log.Printf(" # Paths auto-resolve based on mode:")
log.Printf(" # dev mode: ./data/dev/experiments")
log.Printf(" # prod mode: ./data/prod/experiments")
log.Printf(" api_key = \"your_api_key_here\"")
log.Printf("")
log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs")
@ -101,6 +112,11 @@ func main() {
}
log.Printf("Loaded TOML configuration from %s", cliConfPath)
// Force local mode - TUI runs on server with direct filesystem access
cfg.Host = ""
// Clear BasePath to force mode-based path resolution
cfg.BasePath = ""
// Validate authentication configuration
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
log.Fatalf("Invalid authentication configuration: %v", err)
@ -130,7 +146,7 @@ func main() {
srv, err := services.NewMLServer(cfg)
if err != nil {
log.Fatalf("Failed to connect to server: %v", err)
log.Fatalf("Failed to initialize local server: %v", err)
}
defer func() {
if err := srv.Close(); err != nil {
@ -138,19 +154,24 @@ func main() {
}
}()
// TaskQueue is optional for local mode
tq, err := services.NewTaskQueue(cfg)
if err != nil {
log.Printf("Failed to connect to Redis: %v", err)
return
log.Printf("Warning: Failed to connect to Redis: %v", err)
log.Printf("Continuing without task queue functionality")
tq = nil
}
if tq != nil {
defer func() {
if err := tq.Close(); err != nil {
log.Printf("task queue close error: %v", err)
}
}()
}
defer func() {
if err := tq.Close(); err != nil {
log.Printf("task queue close error: %v", err)
}
}()
// Initialize logger with error level and no debug output
logger := logging.NewLogger(-4, false) // -4 = slog.LevelError
// Initialize logger with file output only (prevents TUI disruption)
logFilePath := filepath.Join(logDir, "tui.log")
logger := logging.NewFileLogger(-4, false, logFilePath) // -4 = slog.LevelError
// Initialize State and Controller
var effectiveAPIKey string
@ -177,14 +198,17 @@ func main() {
go func() {
<-sigChan
logger.Info("Received shutdown signal, closing TUI...")
p.Quit()
}()
if _, err := p.Run(); err != nil {
_ = p.ReleaseTerminal()
log.Printf("Error running TUI: %v", err)
logger.Error("Error running TUI", "error", err)
return
}
// Ensure terminal is released and resources are closed via defer statements
_ = p.ReleaseTerminal()
logger.Info("TUI shutdown complete")
}

View file

@ -19,7 +19,7 @@ security:
api_key_rotation_days: 90
audit_logging:
enabled: true
log_path: "/tmp/fetchml-audit.log"
log_path: "data/dev/logs/fetchml-audit.log"
rate_limit:
enabled: false
requests_per_minute: 60
@ -46,8 +46,8 @@ database:
logging:
level: "info"
file: ""
audit_log: ""
file: "data/dev/logs/fetchml.log"
audit_log: "data/dev/logs/fetchml-audit.log"
resources:
max_workers: 1

View file

@ -1,6 +1,6 @@
base_path: "./data/experiments"
base_path: "./data/dev/experiments"
data_dir: "./data/active"
data_dir: "./data/dev/active"
auth:
enabled: false
@ -19,7 +19,7 @@ security:
api_key_rotation_days: 90
audit_logging:
enabled: true
log_path: "./data/fetchml-audit.log"
log_path: "./data/dev/logs/fetchml-audit.log"
rate_limit:
enabled: false
requests_per_minute: 60
@ -42,12 +42,12 @@ redis:
database:
type: "sqlite"
connection: "./data/fetchml.sqlite"
connection: "./data/dev/fetchml.sqlite"
logging:
level: "info"
file: ""
audit_log: ""
file: "./data/dev/logs/fetchml.log"
audit_log: "./data/dev/logs/fetchml-audit.log"
resources:
max_workers: 1

View file

@ -1,6 +1,6 @@
base_path: "/app/data/experiments"
base_path: "/app/data/prod/experiments"
data_dir: "/data/active"
data_dir: "/app/data/prod/active"
auth:
enabled: true
@ -45,12 +45,12 @@ redis:
database:
type: "sqlite"
connection: "/app/data/experiments/fetch_ml.sqlite"
connection: "/app/data/prod/fetch_ml.sqlite"
logging:
level: "info"
file: "/logs/fetch_ml.log"
audit_log: ""
file: "/app/data/prod/logs/fetch_ml.log"
audit_log: "/app/data/prod/logs/audit.log"
resources:
max_workers: 2

View file

@ -36,6 +36,22 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
// Register HTTP API handlers
s.handlers.RegisterHandlers(mux)
// Register new REST API endpoints for TUI
jobsHandler := jobs.NewHandler(
s.expManager,
s.logger,
s.taskQueue,
s.db,
s.config.BuildAuthConfig(),
nil,
)
// Experiment history endpoint: GET /api/experiments/:id/history
mux.HandleFunc("GET /api/experiments/{id}/history", jobsHandler.GetExperimentHistoryHTTP)
// Team jobs endpoint: GET /api/jobs?all_users=true
mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP)
}
// registerHealthRoutes sets up health check endpoints

View file

@ -12,6 +12,8 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/jfraeys/fetch_ml/internal/audit"
@ -110,6 +112,22 @@ const (
PermJupyterRead = "jupyter:read"
)
// ClientType represents the type of WebSocket client
type ClientType int

const (
	// ClientTypeCLI is a command-line client connection.
	ClientTypeCLI ClientType = iota
	// ClientTypeTUI is a terminal-UI client connection; the handler's
	// broadcast methods deliver push updates only to clients of this type.
	ClientTypeTUI
)

// Client represents a connected WebSocket client
type Client struct {
	conn       *websocket.Conn // underlying connection; broadcast methods write to it
	Type       ClientType      // CLI vs TUI; only TUI clients receive broadcasts
	User       string          // user label recorded at registration time
	RemoteAddr string          // remote address captured at connect time, used in log messages
}
// Handler provides WebSocket handling
type Handler struct {
authConfig *auth.Config
@ -125,6 +143,10 @@ type Handler struct {
jobsHandler *jobs.Handler
jupyterHandler *jupyterj.Handler
datasetsHandler *datasets.Handler
// Client management for push updates
clients map[*Client]bool
clientsMu sync.RWMutex
}
// NewHandler creates a new WebSocket handler
@ -158,6 +180,7 @@ func NewHandler(
jobsHandler: jobsHandler,
jupyterHandler: jupyterHandler,
datasetsHandler: datasetsHandler,
clients: make(map[*Client]bool),
}
}
@ -217,6 +240,25 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleConnection(conn *websocket.Conn) {
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
// Register client
client := &Client{
conn: conn,
Type: ClientTypeTUI, // Assume TUI for now, could detect from handshake
User: "tui-user",
RemoteAddr: conn.RemoteAddr().String(),
}
h.clientsMu.Lock()
h.clients[client] = true
h.clientsMu.Unlock()
defer func() {
h.clientsMu.Lock()
delete(h.clients, client)
h.clientsMu.Unlock()
conn.Close()
}()
for {
messageType, payload, err := conn.ReadMessage()
if err != nil {
@ -765,3 +807,66 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
"message": "Outcome updated",
})
}
// BroadcastJobUpdate sends a job status update to all connected TUI clients.
// CLI clients are skipped. A failed write to one client is logged and does
// not prevent delivery to the remaining clients.
//
// NOTE(review): writes happen under a shared read lock, so two concurrent
// broadcasts may write to the same connection at once; gorilla/websocket
// permits at most one concurrent writer per connection — confirm callers
// serialize broadcasts or add a per-client write mutex.
func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) {
	msg := map[string]any{
		"type":     "job_update",
		"job_name": jobName,
		"status":   status,
		"progress": progress,
		"time":     time.Now().Unix(),
	}
	// Marshal before taking the lock; on failure, log instead of silently
	// broadcasting a nil payload (the original discarded this error).
	payload, err := json.Marshal(msg)
	if err != nil {
		h.logger.Warn("failed to marshal job update", "error", err)
		return
	}

	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()
	for client := range h.clients {
		if client.Type != ClientTypeTUI {
			continue
		}
		if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
			h.logger.Warn("failed to broadcast to client", "error", err, "client", client.RemoteAddr)
		}
	}
}
// BroadcastGPUUpdate sends a GPU status update to all connected TUI clients.
// CLI clients are skipped. A failed write to one client is logged and does
// not prevent delivery to the remaining clients.
//
// NOTE(review): writes happen under a shared read lock, so two concurrent
// broadcasts may write to the same connection at once; gorilla/websocket
// permits at most one concurrent writer per connection — confirm callers
// serialize broadcasts or add a per-client write mutex.
func (h *Handler) BroadcastGPUUpdate(deviceID, utilization int, memoryUsed, memoryTotal int64) {
	msg := map[string]any{
		"type":         "gpu_update",
		"device_id":    deviceID,
		"utilization":  utilization,
		"memory_used":  memoryUsed,
		"memory_total": memoryTotal,
		"time":         time.Now().Unix(),
	}
	// Marshal before taking the lock; on failure, log instead of silently
	// broadcasting a nil payload (the original discarded this error).
	payload, err := json.Marshal(msg)
	if err != nil {
		h.logger.Warn("failed to marshal GPU update", "error", err)
		return
	}

	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()
	for client := range h.clients {
		if client.Type != ClientTypeTUI {
			continue
		}
		if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
			h.logger.Warn("failed to broadcast GPU update", "error", err, "client", client.RemoteAddr)
		}
	}
}
// GetConnectedClientCount reports how many of the currently registered
// WebSocket clients are TUI clients (CLI clients are not counted).
func (h *Handler) GetConnectedClientCount() int {
	h.clientsMu.RLock()
	defer h.clientsMu.RUnlock()

	n := 0
	for c := range h.clients {
		if c.Type == ClientTypeTUI {
			n++
		}
	}
	return n
}

View file

@ -229,6 +229,37 @@ func (s *SmartDefaults) ExpandPath(path string) string {
return path
}
// ModeBasedPath returns the appropriate path for a given mode.
// Recognized modes ("dev", "prod", "ci", "prod-smoke") map to their own
// subtree under ./data; any other mode falls back to the dev subtree.
func ModeBasedPath(mode string, subpath string) string {
	root := "./data/dev"
	switch mode {
	case "dev", "prod", "ci", "prod-smoke":
		root = "./data/" + mode
	}
	return filepath.Join(root, subpath)
}
// ModeBasedBasePath resolves the experiments base directory for the given
// mode (e.g. ./data/dev/experiments in dev mode).
func ModeBasedBasePath(mode string) string {
	const sub = "experiments"
	return ModeBasedPath(mode, sub)
}
// ModeBasedDataDir resolves the active-data directory for the given mode
// (e.g. ./data/dev/active in dev mode).
func ModeBasedDataDir(mode string) string {
	const sub = "active"
	return ModeBasedPath(mode, sub)
}
// ModeBasedLogDir resolves the logs directory for the given mode
// (e.g. ./data/dev/logs in dev mode).
func ModeBasedLogDir(mode string) string {
	const sub = "logs"
	return ModeBasedPath(mode, sub)
}
// GetEnvironmentDescription returns a human-readable description
func (s *SmartDefaults) GetEnvironmentDescription() string {
switch s.Profile {