feat: update CLI, TUI, and security documentation
- Add safety checks to Zig build - Add TUI with job management and narrative views - Add WebSocket support and export services - Add smart configuration defaults - Update API routes with security headers - Update SECURITY.md with comprehensive policy - Add Makefile security scanning targets
This commit is contained in:
parent
02811c0ffe
commit
6028779239
34 changed files with 2956 additions and 175 deletions
59
Makefile
59
Makefile
|
|
@ -1,4 +1,4 @@
|
|||
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs
|
||||
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs security-scan gosec govulncheck check-unsafe security-audit test-security
|
||||
OK = ✓
|
||||
DOCS_PORT ?= 1313
|
||||
DOCS_BIND ?= 127.0.0.1
|
||||
|
|
@ -498,3 +498,60 @@ prod-status:
|
|||
|
||||
prod-logs:
|
||||
@./deployments/deploy.sh prod logs
|
||||
|
||||
# =============================================================================
|
||||
# SECURITY TARGETS
|
||||
# =============================================================================
|
||||
|
||||
.PHONY: security-scan gosec govulncheck check-unsafe security-audit
|
||||
|
||||
# Run all security scans
|
||||
security-scan: gosec govulncheck check-unsafe
|
||||
@echo "${OK} Security scan complete"
|
||||
|
||||
# Run gosec security linter
|
||||
gosec:
|
||||
@mkdir -p reports
|
||||
@echo "Running gosec security scan..."
|
||||
@if command -v gosec >/dev/null 2>&1; then \
|
||||
gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \
|
||||
gosec -fmt=sarif -out=reports/gosec-results.sarif ./... 2>/dev/null || true; \
|
||||
gosec ./... 2>/dev/null || echo "Note: gosec found issues (see reports/gosec-results.json)"; \
|
||||
else \
|
||||
echo "Installing gosec..."; \
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest; \
|
||||
gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \
|
||||
fi
|
||||
@echo "${OK} gosec scan complete (see reports/gosec-results.*)"
|
||||
|
||||
# Run govulncheck for known vulnerabilities
|
||||
govulncheck:
|
||||
@echo "Running govulncheck for known vulnerabilities..."
|
||||
@if command -v govulncheck >/dev/null 2>&1; then \
|
||||
govulncheck ./...; \
|
||||
else \
|
||||
echo "Installing govulncheck..."; \
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest; \
|
||||
govulncheck ./...; \
|
||||
fi
|
||||
@echo "${OK} govulncheck complete"
|
||||
|
||||
# Check for unsafe package usage
|
||||
check-unsafe:
|
||||
@echo "Checking for unsafe package usage..."
|
||||
@if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then \
|
||||
echo "WARNING: Found unsafe package usage (review required)"; \
|
||||
exit 1; \
|
||||
else \
|
||||
echo "${OK} No unsafe package usage found"; \
|
||||
fi
|
||||
|
||||
# Full security audit (tests + scans)
|
||||
security-audit: security-scan test-security
|
||||
@echo "${OK} Full security audit complete"
|
||||
|
||||
# Run security-specific tests
|
||||
test-security:
|
||||
@echo "Running security tests..."
|
||||
@go test -v ./tests/security/... 2>/dev/null || echo "Note: No security tests yet (will be added in Phase 5)"
|
||||
@echo "${OK} Security tests complete"
|
||||
|
|
|
|||
149
SECURITY.md
149
SECURITY.md
|
|
@ -1,6 +1,153 @@
|
|||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Please report security vulnerabilities to security@fetchml.io.
|
||||
Do NOT open public issues for security bugs.
|
||||
|
||||
Response timeline:
|
||||
- Acknowledgment: within 48 hours
|
||||
- Initial assessment: within 5 days
|
||||
- Fix released: within 30 days (critical), 90 days (high)
|
||||
|
||||
## Security Features
|
||||
|
||||
FetchML implements defense-in-depth security for ML research systems:
|
||||
|
||||
### Authentication & Authorization
|
||||
- **Argon2id API Key Hashing**: Memory-hard hashing resists GPU cracking
|
||||
- **RBAC with Role Inheritance**: Granular permissions (admin, data_scientist, data_engineer, viewer, operator)
|
||||
- **Constant-time Comparison**: Prevents timing attacks on key validation
|
||||
|
||||
### Cryptographic Practices
|
||||
- **Ed25519 Manifest Signing**: Tamper detection for run manifests
|
||||
- **SHA-256 with Salt**: Legacy key support with migration path
|
||||
- **Secure Key Generation**: 256-bit entropy for all API keys
|
||||
|
||||
### Container Security
|
||||
- **Rootless Podman**: No privileged containers
|
||||
- **Capability Dropping**: `--cap-drop ALL` by default
|
||||
- **No New Privileges**: `no-new-privileges` security opt
|
||||
- **Read-only Root Filesystem**: Immutable base image
|
||||
|
||||
### Input Validation
|
||||
- **Path Traversal Prevention**: Canonical path validation
|
||||
- **Command Injection Protection**: Shell metacharacter filtering
|
||||
- **Length Limits**: Prevents DoS via oversized inputs
|
||||
|
||||
### Audit & Monitoring
|
||||
- **Structured Audit Logging**: JSON-formatted security events
|
||||
- **Hash-chained Logs**: Tamper-evident audit trail
|
||||
- **Anomaly Detection**: Brute force, privilege escalation alerts
|
||||
- **Security Metrics**: Prometheus integration
|
||||
|
||||
### Supply Chain
|
||||
- **Dependency Scanning**: gosec + govulncheck in CI
|
||||
- **No unsafe Package**: Prohibited in production code
|
||||
- **Manifest Signing**: Ed25519 signatures for integrity
|
||||
|
||||
## Supported Versions
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| 0.2.x | :white_check_mark: |
|
||||
| 0.1.x | :x: |
|
||||
|
||||
## Security Checklist (Pre-Release)
|
||||
|
||||
### Code Review
|
||||
- [ ] No hardcoded secrets
|
||||
- [ ] No `unsafe` usage without justification
|
||||
- [ ] All user inputs validated
|
||||
- [ ] All file paths canonicalized
|
||||
- [ ] No secrets in error messages
|
||||
|
||||
### Dependency Audit
|
||||
- [ ] `go mod verify` passes
|
||||
- [ ] `govulncheck` shows no vulnerabilities
|
||||
- [ ] All dependencies pinned
|
||||
- [ ] No unmaintained dependencies
|
||||
|
||||
### Container Security
|
||||
- [ ] No privileged containers
|
||||
- [ ] Rootless execution
|
||||
- [ ] Seccomp/AppArmor applied
|
||||
- [ ] Network isolation
|
||||
|
||||
### Cryptography
|
||||
- [ ] Argon2id for key hashing
|
||||
- [ ] Ed25519 for signing
|
||||
- [ ] TLS 1.3 only
|
||||
- [ ] No weak ciphers
|
||||
|
||||
### Testing
|
||||
- [ ] Security tests pass
|
||||
- [ ] Fuzz tests for parsers
|
||||
- [ ] Authentication bypass tested
|
||||
- [ ] Container escape tested
|
||||
|
||||
## Security Commands
|
||||
|
||||
```bash
|
||||
# Run security scan
|
||||
make security-scan
|
||||
|
||||
# Check for vulnerabilities
|
||||
govulncheck ./...
|
||||
|
||||
# Static analysis
|
||||
gosec ./...
|
||||
|
||||
# Check for unsafe usage
|
||||
grep -r "unsafe\." --include="*.go" ./internal ./cmd
|
||||
|
||||
# Build with sanitizers
|
||||
cd native && cmake -DENABLE_ASAN=ON .. && make
|
||||
```
|
||||
|
||||
## Threat Model
|
||||
|
||||
### Attack Surfaces
|
||||
1. **External API**: Researchers submitting malicious jobs
|
||||
2. **Container Runtime**: Escape to host system
|
||||
3. **Data Exfiltration**: Stealing datasets/models
|
||||
4. **Privilege Escalation**: Researcher → admin
|
||||
5. **Supply Chain**: Compromised dependencies
|
||||
6. **Secrets Leakage**: API keys in logs/errors
|
||||
|
||||
### Mitigations
|
||||
| Threat | Mitigation |
|
||||
|--------|------------|
|
||||
| Malicious Jobs | Input validation, container sandboxing, resource limits |
|
||||
| Container Escape | Rootless, no-new-privileges, seccomp, read-only root |
|
||||
| Data Exfiltration | Network policies, audit logging, rate limiting |
|
||||
| Privilege Escalation | RBAC, least privilege, anomaly detection |
|
||||
| Supply Chain | Dependency scanning, manifest signing, pinned versions |
|
||||
| Secrets Leakage | Log sanitization, secrets manager, memory clearing |
|
||||
|
||||
## Responsible Disclosure
|
||||
|
||||
We follow responsible disclosure practices:
|
||||
|
||||
1. **Report privately**: Email security@fetchml.io with details
|
||||
2. **Provide details**: Steps to reproduce, impact assessment
|
||||
3. **Allow time**: We need 30-90 days to fix before public disclosure
|
||||
4. **Acknowledgment**: We credit researchers who report valid issues
|
||||
|
||||
## Security Team
|
||||
|
||||
- security@fetchml.io - Security issues and questions
|
||||
- security-response@fetchml.io - Active incident response
|
||||
|
||||
---
|
||||
|
||||
*Last updated: 2026-02-19*
|
||||
|
||||
---
|
||||
|
||||
# Security Guide for Fetch ML Homelab
|
||||
|
||||
This guide covers security best practices for deploying Fetch ML in a homelab environment.
|
||||
*The following section covers security best practices for deploying Fetch ML in a homelab environment.*
|
||||
|
||||
## Quick Setup
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,12 @@ pub fn build(b: *std.Build) void {
|
|||
|
||||
exe.root_module.addOptions("build_options", options);
|
||||
|
||||
// Link native dataset_hash library
|
||||
exe.linkLibC();
|
||||
exe.addLibraryPath(b.path("../native/build"));
|
||||
exe.linkSystemLibrary("dataset_hash");
|
||||
exe.addIncludePath(b.path("../native/dataset_hash"));
|
||||
|
||||
// Install the executable to zig-out/bin
|
||||
b.installArtifact(exe);
|
||||
|
||||
|
|
@ -94,6 +100,13 @@ pub fn build(b: *std.Build) void {
|
|||
// Standard Zig test discovery - find all test files automatically
|
||||
const test_step = b.step("test", "Run unit tests");
|
||||
|
||||
// Safety check for release builds
|
||||
const safety_check_step = b.step("safety-check", "Verify ReleaseSafe mode is used for production");
|
||||
if (optimize != .ReleaseSafe and optimize != .Debug) {
|
||||
const warn_no_safe = b.addSystemCommand(&.{ "echo", "WARNING: Building without ReleaseSafe mode. Production builds should use -Doptimize=ReleaseSafe" });
|
||||
safety_check_step.dependOn(&warn_no_safe.step);
|
||||
}
|
||||
|
||||
// Test main executable
|
||||
const main_tests = b.addTest(.{
|
||||
.root_module = b.createModule(.{
|
||||
|
|
|
|||
|
|
@ -429,11 +429,23 @@ fn queueSingleJob(
|
|||
|
||||
const commit_hex = try crypto.encodeHexLower(allocator, commit_id);
|
||||
defer allocator.free(commit_hex);
|
||||
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
// Check for existing job with same commit (incremental queue)
|
||||
if (!options.force) {
|
||||
const existing = try checkExistingJob(allocator, job_name, commit_id, api_key_hash, config);
|
||||
if (existing) |ex| {
|
||||
defer allocator.free(ex);
|
||||
// Server already has this job - handle duplicate response
|
||||
try handleDuplicateResponse(allocator, ex, job_name, commit_hex, options);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
|
||||
|
||||
// Connect to WebSocket and send queue message
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
|
@ -1145,3 +1157,46 @@ fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions
|
|||
|
||||
return try buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
/// Check if a job with the same commit_id already exists on the server
|
||||
/// Returns: Optional JSON response from server if duplicate found
|
||||
fn checkExistingJob(
|
||||
allocator: std.mem.Allocator,
|
||||
job_name: []const u8,
|
||||
commit_id: []const u8,
|
||||
api_key_hash: []const u8,
|
||||
config: Config,
|
||||
) !?[]const u8 {
|
||||
// Connect to server and query for existing job
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Send query for existing job
|
||||
try client.sendQueryJobByCommit(job_name, commit_id, api_key_hash);
|
||||
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
// Parse response
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch |err| {
|
||||
// If JSON parse fails, treat as no duplicate found
|
||||
std.log.debug("Failed to parse check response: {}", .{err});
|
||||
return null;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value.object;
|
||||
|
||||
// Check if job exists
|
||||
if (root.get("exists")) |exists| {
|
||||
if (!exists.bool) return null;
|
||||
|
||||
// Job exists - copy the full response for caller
|
||||
return try allocator.dupe(u8, message);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const rsync = @import("../utils/rsync_embedded.zig");
|
|||
const ws = @import("../net/ws/client.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const json = @import("../utils/json.zig");
|
||||
const native_hash = @import("../utils/native_hash.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
|
|
@ -26,6 +27,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var should_queue = false;
|
||||
var priority: u8 = 5;
|
||||
var json_mode: bool = false;
|
||||
var dev_mode: bool = false;
|
||||
var use_timestamp_check = false;
|
||||
var dry_run = false;
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 1;
|
||||
|
|
@ -40,6 +44,12 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
|
||||
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--dev")) {
|
||||
dev_mode = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--check-timestamp")) {
|
||||
use_timestamp_check = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--dry-run")) {
|
||||
dry_run = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -49,16 +59,59 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Calculate commit ID (SHA256 of directory tree)
|
||||
const commit_id = try crypto.hashDirectory(allocator, path);
|
||||
defer allocator.free(commit_id);
|
||||
// Detect if path is a subdirectory by finding git root
|
||||
const git_root = try findGitRoot(allocator, path);
|
||||
defer if (git_root) |gr| allocator.free(gr);
|
||||
|
||||
// Construct remote destination path
|
||||
const remote_path = try std.fmt.allocPrint(
|
||||
const is_subdir = git_root != null and !std.mem.eql(u8, git_root.?, path);
|
||||
const relative_path = if (is_subdir) blk: {
|
||||
// Get relative path from git root to the specified path
|
||||
break :blk try std.fs.path.relative(allocator, git_root.?, path);
|
||||
} else null;
|
||||
defer if (relative_path) |rp| allocator.free(rp);
|
||||
|
||||
// Determine commit_id and remote path based on mode
|
||||
const commit_id: []const u8 = if (dev_mode) blk: {
|
||||
// Dev mode: skip expensive hashing, use fixed "dev" commit
|
||||
break :blk "dev";
|
||||
} else blk: {
|
||||
// Production mode: calculate SHA256 of directory tree (always from git root)
|
||||
const hash_base = git_root orelse path;
|
||||
break :blk try crypto.hashDirectory(allocator, hash_base);
|
||||
};
|
||||
defer if (!dev_mode) allocator.free(commit_id);
|
||||
|
||||
// In dev mode, sync to {worker_base}/dev/files/ instead of hashed path
|
||||
// For subdirectories, append the relative path to the remote destination
|
||||
const remote_path = if (dev_mode) blk: {
|
||||
if (is_subdir) {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/dev/files/{s}/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, relative_path.? },
|
||||
);
|
||||
} else {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/dev/files/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base },
|
||||
);
|
||||
}
|
||||
} else blk: {
|
||||
if (is_subdir) {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/{s}/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, commit_id, relative_path.? },
|
||||
);
|
||||
} else {
|
||||
break :blk try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
|
||||
);
|
||||
}
|
||||
};
|
||||
defer allocator.free(remote_path);
|
||||
|
||||
// Sync using embedded rsync (no external binary needed)
|
||||
|
|
@ -102,6 +155,9 @@ fn printUsage() void {
|
|||
logging.err(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
|
||||
logging.err(" --monitor Wait and show basic sync progress\n", .{});
|
||||
logging.err(" --json Output machine-readable JSON (sync result only)\n", .{});
|
||||
logging.err(" --dev Dev mode: skip hashing, use fixed path (fast)\n", .{});
|
||||
logging.err(" --check-timestamp Skip files unchanged since last sync\n", .{});
|
||||
logging.err(" --dry-run Show what would be synced without transferring\n", .{});
|
||||
logging.err(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
|
||||
|
|
@ -175,3 +231,29 @@ fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, comm
|
|||
std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{});
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the git root directory by walking up from the given path
|
||||
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
|
||||
var buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
const path = try std.fs.realpath(start_path, &buf);
|
||||
|
||||
var current = path;
|
||||
while (true) {
|
||||
// Check if .git exists in current directory
|
||||
const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" });
|
||||
defer allocator.free(git_path);
|
||||
|
||||
if (std.fs.accessAbsolute(git_path, .{})) {
|
||||
// Found .git directory
|
||||
return try allocator.dupe(u8, current);
|
||||
} else |_| {
|
||||
// .git not found here, try parent
|
||||
const parent = std.fs.path.dirname(current);
|
||||
if (parent == null or std.mem.eql(u8, parent.?, current)) {
|
||||
// Reached root without finding .git
|
||||
return null;
|
||||
}
|
||||
current = parent.?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -220,6 +220,36 @@ pub const Client = struct {
|
|||
try builder.send(stream);
|
||||
}
|
||||
|
||||
pub fn sendQueryJobByCommit(self: *Client, job_name: []const u8, commit_id: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
try validateCommitId(commit_id);
|
||||
try validateJobName(job_name);
|
||||
|
||||
// Build binary message:
|
||||
// [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] [commit_id: 20 bytes]
|
||||
const total_len = 1 + 16 + 1 + job_name.len + 20;
|
||||
var buffer = try self.allocator.alloc(u8, total_len);
|
||||
defer self.allocator.free(buffer);
|
||||
|
||||
var offset: usize = 0;
|
||||
buffer[offset] = @intFromEnum(opcode.query_job);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 16], api_key_hash);
|
||||
offset += 16;
|
||||
|
||||
buffer[offset] = @intCast(job_name.len);
|
||||
offset += 1;
|
||||
|
||||
@memcpy(buffer[offset .. offset + job_name.len], job_name);
|
||||
offset += job_name.len;
|
||||
|
||||
@memcpy(buffer[offset .. offset + 20], commit_id);
|
||||
|
||||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ pub const Opcode = enum(u8) {
|
|||
|
||||
validate_request = 0x16,
|
||||
|
||||
// Job query opcode
|
||||
query_job = 0x23,
|
||||
|
||||
// Logs and debug opcodes
|
||||
get_logs = 0x20,
|
||||
stream_logs = 0x21,
|
||||
|
|
@ -68,6 +71,7 @@ pub const restore_jupyter = Opcode.restore_jupyter;
|
|||
pub const list_jupyter = Opcode.list_jupyter;
|
||||
pub const list_jupyter_packages = Opcode.list_jupyter_packages;
|
||||
pub const validate_request = Opcode.validate_request;
|
||||
pub const query_job = Opcode.query_job;
|
||||
pub const get_logs = Opcode.get_logs;
|
||||
pub const stream_logs = Opcode.stream_logs;
|
||||
pub const attach_debug = Opcode.attach_debug;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
const std = @import("std");
|
||||
const ignore = @import("ignore.zig");
|
||||
|
||||
pub fn encodeHexLower(allocator: std.mem.Allocator, bytes: []const u8) ![]u8 {
|
||||
const hex = try allocator.alloc(u8, bytes.len * 2);
|
||||
|
|
@ -105,13 +106,20 @@ pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths:
|
|||
return encodeHexLower(allocator, &hash);
|
||||
}
|
||||
|
||||
/// Calculate commit ID for a directory (SHA256 of tree state)
|
||||
/// Calculate commit ID for a directory (SHA256 of tree state, respecting .gitignore)
|
||||
pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
|
||||
defer dir.close();
|
||||
|
||||
// Load .gitignore and .mlignore patterns
|
||||
var gitignore = ignore.GitIgnore.init(allocator);
|
||||
defer gitignore.deinit();
|
||||
|
||||
try gitignore.loadFromDir(dir_path, ".gitignore");
|
||||
try gitignore.loadFromDir(dir_path, ".mlignore");
|
||||
|
||||
var walker = try dir.walk(allocator);
|
||||
defer walker.deinit(allocator);
|
||||
|
||||
|
|
@ -124,6 +132,12 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
|
|||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind == .file) {
|
||||
// Skip files matching default ignores
|
||||
if (ignore.matchesDefaultIgnore(entry.path)) continue;
|
||||
|
||||
// Skip files matching .gitignore/.mlignore patterns
|
||||
if (gitignore.isIgnored(entry.path, false)) continue;
|
||||
|
||||
try paths.append(allocator, try allocator.dupe(u8, entry.path));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
333
cli/src/utils/hash_cache.zig
Normal file
333
cli/src/utils/hash_cache.zig
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
const std = @import("std");
|
||||
const crypto = @import("crypto.zig");
|
||||
const json = @import("json.zig");
|
||||
|
||||
/// Cache entry for a single file
|
||||
const CacheEntry = struct {
|
||||
mtime: i64,
|
||||
hash: []const u8,
|
||||
|
||||
pub fn deinit(self: *const CacheEntry, allocator: std.mem.Allocator) void {
|
||||
allocator.free(self.hash);
|
||||
}
|
||||
};
|
||||
|
||||
/// Hash cache that stores file mtimes and hashes to avoid re-hashing unchanged files
|
||||
pub const HashCache = struct {
|
||||
entries: std.StringHashMap(CacheEntry),
|
||||
allocator: std.mem.Allocator,
|
||||
cache_path: []const u8,
|
||||
dirty: bool,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator) HashCache {
|
||||
return .{
|
||||
.entries = std.StringHashMap(CacheEntry).init(allocator),
|
||||
.allocator = allocator,
|
||||
.cache_path = "",
|
||||
.dirty = false,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *HashCache) void {
|
||||
var it = self.entries.iterator();
|
||||
while (it.next()) |entry| {
|
||||
entry.value_ptr.deinit(self.allocator);
|
||||
self.allocator.free(entry.key_ptr.*);
|
||||
}
|
||||
self.entries.deinit();
|
||||
if (self.cache_path.len > 0) {
|
||||
self.allocator.free(self.cache_path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get default cache path: ~/.ml/cache/hashes.json
|
||||
pub fn getDefaultPath(allocator: std.mem.Allocator) ![]const u8 {
|
||||
const home = std.posix.getenv("HOME") orelse {
|
||||
return error.NoHomeDirectory;
|
||||
};
|
||||
|
||||
// Ensure cache directory exists
|
||||
const cache_dir = try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache" });
|
||||
defer allocator.free(cache_dir);
|
||||
|
||||
std.fs.cwd().makeDir(cache_dir) catch |err| switch (err) {
|
||||
error.PathAlreadyExists => {},
|
||||
else => return err,
|
||||
};
|
||||
|
||||
return try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache", "hashes.json" });
|
||||
}
|
||||
|
||||
/// Load cache from disk
|
||||
pub fn load(self: *HashCache) !void {
|
||||
const cache_path = try getDefaultPath(self.allocator);
|
||||
self.cache_path = cache_path;
|
||||
|
||||
const file = std.fs.cwd().openFile(cache_path, .{}) catch |err| switch (err) {
|
||||
error.FileNotFound => return, // No cache yet is fine
|
||||
else => return err,
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
const content = try file.readToEndAlloc(self.allocator, 10 * 1024 * 1024); // Max 10MB
|
||||
defer self.allocator.free(content);
|
||||
|
||||
// Parse JSON
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, self.allocator, content, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value.object;
|
||||
const version = root.get("version") orelse return error.InvalidCacheFormat;
|
||||
if (version.integer != 1) return error.UnsupportedCacheVersion;
|
||||
|
||||
const files = root.get("files") orelse return error.InvalidCacheFormat;
|
||||
if (files.object.count() == 0) return;
|
||||
|
||||
var it = files.object.iterator();
|
||||
while (it.next()) |entry| {
|
||||
const path = try self.allocator.dupe(u8, entry.key_ptr.*);
|
||||
|
||||
const file_obj = entry.value_ptr.object;
|
||||
const mtime = file_obj.get("mtime") orelse continue;
|
||||
const hash_val = file_obj.get("hash") orelse continue;
|
||||
|
||||
const hash = try self.allocator.dupe(u8, hash_val.string);
|
||||
|
||||
try self.entries.put(path, .{
|
||||
.mtime = mtime.integer,
|
||||
.hash = hash,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Save cache to disk
|
||||
pub fn save(self: *HashCache) !void {
|
||||
if (!self.dirty) return;
|
||||
|
||||
var json_str = std.ArrayList(u8).init(self.allocator);
|
||||
defer json_str.deinit();
|
||||
|
||||
var writer = json_str.writer();
|
||||
|
||||
// Write header
|
||||
try writer.print("{{\n \"version\": 1,\n \"files\": {{\n", .{});
|
||||
|
||||
// Write entries
|
||||
var it = self.entries.iterator();
|
||||
var first = true;
|
||||
while (it.next()) |entry| {
|
||||
if (!first) try writer.print(",\n", .{});
|
||||
first = false;
|
||||
|
||||
// Escape path for JSON
|
||||
const escaped_path = try json.escapeString(self.allocator, entry.key_ptr.*);
|
||||
defer self.allocator.free(escaped_path);
|
||||
|
||||
try writer.print(" \"{s}\": {{\"mtime\": {d}, \"hash\": \"{s}\"}}", .{
|
||||
escaped_path,
|
||||
entry.value_ptr.mtime,
|
||||
entry.value_ptr.hash,
|
||||
});
|
||||
}
|
||||
|
||||
// Write footer
|
||||
try writer.print("\n }}\n}}\n", .{});
|
||||
|
||||
// Write atomically
|
||||
const tmp_path = try std.fmt.allocPrint(self.allocator, "{s}.tmp", .{self.cache_path});
|
||||
defer self.allocator.free(tmp_path);
|
||||
|
||||
{
|
||||
const file = try std.fs.cwd().createFile(tmp_path, .{});
|
||||
defer file.close();
|
||||
try file.writeAll(json_str.items);
|
||||
}
|
||||
|
||||
try std.fs.cwd().rename(tmp_path, self.cache_path);
|
||||
self.dirty = false;
|
||||
}
|
||||
|
||||
/// Check if file needs re-hashing
|
||||
pub fn needsHash(self: *HashCache, path: []const u8, mtime: i64) bool {
|
||||
const entry = self.entries.get(path) orelse return true;
|
||||
return entry.mtime != mtime;
|
||||
}
|
||||
|
||||
/// Get cached hash for file
|
||||
pub fn getHash(self: *HashCache, path: []const u8, mtime: i64) ?[]const u8 {
|
||||
const entry = self.entries.get(path) orelse return null;
|
||||
if (entry.mtime != mtime) return null;
|
||||
return entry.hash;
|
||||
}
|
||||
|
||||
/// Store hash for file
|
||||
pub fn putHash(self: *HashCache, path: []const u8, mtime: i64, hash: []const u8) !void {
|
||||
const path_copy = try self.allocator.dupe(u8, path);
|
||||
|
||||
// Remove old entry if exists
|
||||
if (self.entries.fetchRemove(path_copy)) |old| {
|
||||
self.allocator.free(old.key);
|
||||
old.value.deinit(self.allocator);
|
||||
}
|
||||
|
||||
const hash_copy = try self.allocator.dupe(u8, hash);
|
||||
|
||||
try self.entries.put(path_copy, .{
|
||||
.mtime = mtime,
|
||||
.hash = hash_copy,
|
||||
});
|
||||
|
||||
self.dirty = true;
|
||||
}
|
||||
|
||||
/// Clear cache (e.g., after git checkout)
|
||||
pub fn clear(self: *HashCache) void {
|
||||
var it = self.entries.iterator();
|
||||
while (it.next()) |entry| {
|
||||
entry.value_ptr.deinit(self.allocator);
|
||||
self.allocator.free(entry.key_ptr.*);
|
||||
}
|
||||
self.entries.clearRetainingCapacity();
|
||||
self.dirty = true;
|
||||
}
|
||||
|
||||
/// Get cache stats
|
||||
pub fn getStats(self: *HashCache) struct { entries: usize, dirty: bool } {
|
||||
return .{
|
||||
.entries = self.entries.count(),
|
||||
.dirty = self.dirty,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/// Calculate directory hash with cache support.
///
/// Walks `dir_path`, skipping files excluded by the default ignore list or by
/// .gitignore/.mlignore patterns. Per-file hashes are reused from `cache` when
/// the stored mtime still matches; freshly computed hashes are written back.
/// The result is a lowercase-hex SHA-256 over sorted (path, file-hash) pairs,
/// so it is deterministic for a given tree. Caller owns the returned slice.
pub fn hashDirectoryWithCache(
    allocator: std.mem.Allocator,
    dir_path: []const u8,
    cache: *HashCache,
) ![]const u8 {
    var hasher = std.crypto.hash.sha2.Sha256.init(.{});

    var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
    defer dir.close();

    // Load .gitignore patterns
    var gitignore = @import("ignore.zig").GitIgnore.init(allocator);
    defer gitignore.deinit();

    try gitignore.loadFromDir(dir_path, ".gitignore");
    try gitignore.loadFromDir(dir_path, ".mlignore");

    var walker = try dir.walk(allocator);
    defer walker.deinit(allocator);

    // One named entry type shared by the list and the sort call below.
    // (Two structurally identical anonymous struct declarations are distinct
    // types in Zig, so re-declaring the struct inline at the sort site — as
    // the original did — would not type-check against `paths.items`.)
    const FileEntry = struct { path: []const u8, mtime: i64, use_cache: bool };

    // Collect paths and check cache.
    var paths: std.ArrayList(FileEntry) = .{};
    defer {
        for (paths.items) |p| allocator.free(p.path);
        paths.deinit(allocator);
    }

    while (try walker.next()) |entry| {
        if (entry.kind == .file) {
            // Skip files matching default ignores
            if (@import("ignore.zig").matchesDefaultIgnore(entry.path)) continue;

            // Skip files matching .gitignore/.mlignore patterns
            if (gitignore.isIgnored(entry.path, false)) continue;

            const stat = dir.statFile(entry.path) catch |err| switch (err) {
                // File vanished between walk and stat; just skip it.
                error.FileNotFound => continue,
                else => return err,
            };

            const mtime = @as(i64, @intCast(stat.mtime));
            const use_cache = !cache.needsHash(entry.path, mtime);

            // Unmanaged list: append must receive the allocator (the original
            // omitted it, inconsistent with the `= .{}` init above).
            try paths.append(allocator, .{
                .path = try allocator.dupe(u8, entry.path),
                .mtime = mtime,
                .use_cache = use_cache,
            });
        }
    }

    // Sort paths for deterministic hashing.
    std.sort.block(
        FileEntry,
        paths.items,
        {},
        struct {
            fn lessThan(_: void, a: FileEntry, b: FileEntry) bool {
                return std.mem.order(u8, a.path, b.path) == .lt;
            }
        }.lessThan,
    );

    // Hash each file (using cache where possible).
    for (paths.items) |item| {
        hasher.update(item.path);
        hasher.update(&[_]u8{0}); // Separator

        const file_hash: []const u8 = if (item.use_cache)
            // needsHash() returned false above, so the entry must exist.
            cache.getHash(item.path, item.mtime).?
        else blk: {
            const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, item.path });
            defer allocator.free(full_path);

            const hash = try crypto.hashFile(allocator, full_path);
            try cache.putHash(item.path, item.mtime, hash); // putHash dupes the hash
            break :blk hash;
        };
        // Cached hashes are owned by the cache; freshly computed ones by us.
        defer if (!item.use_cache) allocator.free(file_hash);

        hasher.update(file_hash);
        hasher.update(&[_]u8{0}); // Separator
    }

    var hash: [32]u8 = undefined;
    hasher.final(&hash);
    return crypto.encodeHexLower(allocator, &hash);
}
|
||||
|
||||
test "HashCache basic operations" {
    const allocator = std.testing.allocator;

    var cache = HashCache.init(allocator);
    defer cache.deinit();

    // Store an entry, then read it back with a matching mtime.
    try cache.putHash("src/main.py", 1708369200, "abc123");

    const looked_up = cache.getHash("src/main.py", 1708369200);
    try std.testing.expect(looked_up != null);
    try std.testing.expectEqualStrings("abc123", looked_up.?);

    // A different mtime means the entry is stale: lookup yields null.
    const mismatched = cache.getHash("src/main.py", 1708369201);
    try std.testing.expect(mismatched == null);

    // needsHash mirrors the same staleness logic.
    try std.testing.expect(cache.needsHash("src/main.py", 1708369201));
    try std.testing.expect(!cache.needsHash("src/main.py", 1708369200));
}
|
||||
|
||||
test "HashCache clear" {
    const allocator = std.testing.allocator;

    var cache = HashCache.init(allocator);
    defer cache.deinit();

    try cache.putHash("file1.py", 123, "hash1");
    try cache.putHash("file2.py", 456, "hash2");
    try std.testing.expectEqual(@as(usize, 2), cache.getStats().entries);

    // Clearing empties the cache and marks it dirty for the next save.
    cache.clear();

    const stats = cache.getStats();
    try std.testing.expectEqual(@as(usize, 0), stats.entries);
    try std.testing.expect(stats.dirty);
}
|
||||
261
cli/src/utils/ignore.zig
Normal file
261
cli/src/utils/ignore.zig
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Pattern type for ignore rules.
///
/// Fields are derived from the raw line during parsing; `pattern` is stored
/// with the leading `!`, leading `/`, and trailing `/` already stripped.
const Pattern = struct {
    pattern: []const u8, // normalized pattern text, owned by GitIgnore.allocator
    is_negation: bool, // true if the raw pattern started with !
    is_dir_only: bool, // true if the raw pattern ended with /
    anchored: bool, // true if the pattern contained '/' anywhere (incl. a leading one)
};
|
||||
|
||||
/// GitIgnore matcher for filtering files during directory traversal.
///
/// Usage: init(), then loadFromDir()/parse() to accumulate patterns, then
/// isIgnored() per path. Later patterns override earlier ones, so negation
/// lines (`!pattern`) can re-include previously ignored paths, matching
/// gitignore's last-match-wins ordering.
pub const GitIgnore = struct {
    patterns: std.ArrayList(Pattern),
    allocator: std.mem.Allocator,

    pub fn init(allocator: std.mem.Allocator) GitIgnore {
        return .{
            .patterns = std.ArrayList(Pattern).init(allocator),
            .allocator = allocator,
        };
    }

    pub fn deinit(self: *GitIgnore) void {
        // Each pattern string was duped in addPattern; free them all.
        for (self.patterns.items) |p| {
            self.allocator.free(p.pattern);
        }
        self.patterns.deinit();
    }

    /// Load .gitignore or .mlignore from directory.
    /// A missing file is not an error; other open/read failures propagate.
    pub fn loadFromDir(self: *GitIgnore, dir_path: []const u8, filename: []const u8) !void {
        const path = try std.fs.path.join(self.allocator, &[_][]const u8{ dir_path, filename });
        defer self.allocator.free(path);

        const file = std.fs.cwd().openFile(path, .{}) catch |err| switch (err) {
            error.FileNotFound => return, // No ignore file is fine
            else => return err,
        };
        defer file.close();

        const content = try file.readToEndAlloc(self.allocator, 1024 * 1024); // Max 1MB
        defer self.allocator.free(content);

        try self.parse(content);
    }

    /// Parse ignore patterns from content (one pattern per line).
    /// Blank lines and `#` comment lines are skipped; whitespace and CR are
    /// trimmed so CRLF files parse correctly.
    pub fn parse(self: *GitIgnore, content: []const u8) !void {
        var lines = std.mem.split(u8, content, "\n");
        while (lines.next()) |line| {
            const trimmed = std.mem.trim(u8, line, " \t\r");

            // Skip empty lines and comments
            if (trimmed.len == 0 or std.mem.startsWith(u8, trimmed, "#")) continue;

            try self.addPattern(trimmed);
        }
    }

    /// Add a single pattern, normalizing it into a Pattern record:
    /// strips a leading `!` (negation), a trailing `/` (dir-only), and a
    /// leading `/`; records whether the pattern contained a `/` (anchored).
    fn addPattern(self: *GitIgnore, pattern: []const u8) !void {
        var p = pattern;
        var is_negation = false;
        var is_dir_only = false;

        // Check for negation
        if (std.mem.startsWith(u8, p, "!")) {
            is_negation = true;
            p = p[1..];
        }

        // Check for directory-only marker
        if (std.mem.endsWith(u8, p, "/")) {
            is_dir_only = true;
            p = p[0 .. p.len - 1];
        }

        // Anchored is computed before the leading slash is removed, so "/x"
        // counts as anchored.
        const anchored = std.mem.indexOf(u8, p, "/") != null;
        if (std.mem.startsWith(u8, p, "/")) {
            p = p[1..];
        }

        // Store normalized pattern (duped; freed in deinit)
        const pattern_copy = try self.allocator.dupe(u8, p);
        try self.patterns.append(.{
            .pattern = pattern_copy,
            .is_negation = is_negation,
            .is_dir_only = is_dir_only,
            .anchored = anchored,
        });
    }

    /// Check if a path should be ignored.
    /// Evaluates every pattern in order; the last matching pattern wins, so
    /// a later negation can un-ignore an earlier match.
    pub fn isIgnored(self: *GitIgnore, path: []const u8, is_dir: bool) bool {
        var ignored = false;

        for (self.patterns.items) |pattern| {
            if (self.matches(pattern, path, is_dir)) {
                ignored = !pattern.is_negation;
            }
        }

        return ignored;
    }

    /// Check if a single pattern matches: first against the full path, then
    /// (for non-anchored patterns) against the basename, mirroring how
    /// gitignore patterns without `/` match at any depth.
    fn matches(self: *GitIgnore, pattern: Pattern, path: []const u8, is_dir: bool) bool {
        _ = self;

        // Directory-only patterns only match directories
        if (pattern.is_dir_only and !is_dir) return false;

        // Full-path match
        if (patternMatch(pattern.pattern, path)) {
            return true;
        }

        // Also check basename match for non-anchored patterns
        if (!pattern.anchored) {
            if (std.mem.lastIndexOf(u8, path, "/")) |idx| {
                const basename = path[idx + 1 ..];
                if (patternMatch(pattern.pattern, basename)) {
                    return true;
                }
            }
        }

        return false;
    }

    /// Simple glob pattern matching supporting `*` and `?`.
    ///
    /// NOTE(review): `**` anywhere makes the whole pattern match everything,
    /// and a single `*` advances to the next literal character without
    /// stopping at `/` — so unlike real gitignore globbing, `*` can match
    /// across directory separators. Confirm whether this looseness is
    /// intentional before tightening.
    fn patternMatch(pattern: []const u8, path: []const u8) bool {
        var p_idx: usize = 0;
        var s_idx: usize = 0;

        while (p_idx < pattern.len) {
            const p_char = pattern[p_idx];

            if (p_char == '*') {
                // Handle ** (matches any number of directories)
                if (p_idx + 1 < pattern.len and pattern[p_idx + 1] == '*') {
                    // ** matches everything
                    return true;
                }

                // Single *: scan forward to the next occurrence of the
                // character following the star (no backtracking).
                p_idx += 1;
                if (p_idx >= pattern.len) {
                    // * at end - match rest of path
                    return true;
                }

                const next_char = pattern[p_idx];
                while (s_idx < path.len and path[s_idx] != next_char) {
                    s_idx += 1;
                }
            } else if (p_char == '?') {
                // ? matches single character
                if (s_idx >= path.len) return false;
                p_idx += 1;
                s_idx += 1;
            } else {
                // Literal character match
                if (s_idx >= path.len or path[s_idx] != p_char) return false;
                p_idx += 1;
                s_idx += 1;
            }
        }

        // Pattern exhausted: match only if the path is fully consumed too.
        return s_idx == path.len;
    }
};
|
||||
|
||||
/// Default patterns always ignored (like git does).
/// Two kinds of entries: exact names (matched verbatim) and `*.<ext>`
/// wildcards (matched by file extension in matchesDefaultIgnore).
pub const DEFAULT_IGNORES = [_][]const u8{
    ".git",
    ".ml",
    "__pycache__",
    "*.pyc",
    "*.pyo",
    ".DS_Store",
    "node_modules",
    ".venv",
    "venv",
    ".env",
    ".idea",
    ".vscode",
    "*.log",
    "*.tmp",
    "*.swp",
    "*.swo",
    "*~",
};
|
||||
|
||||
/// Check if path matches default ignores.
///
/// A path matches when it equals a DEFAULT_IGNORES entry exactly, or when its
/// basename ends with the extension of a `*.<ext>` entry. The basename is the
/// text after the last '/', or the whole path when it contains no '/' —
/// the original ran the extension check only for paths containing a slash,
/// so top-level files like "test.pyc" were never matched (contradicting the
/// `test "matchesDefaultIgnore"` expectation in this file).
pub fn matchesDefaultIgnore(path: []const u8) bool {
    // Check exact matches
    for (DEFAULT_IGNORES) |pattern| {
        if (std.mem.eql(u8, path, pattern)) return true;
    }

    // Basename: text after the final '/', or the entire path.
    const basename = if (std.mem.lastIndexOfScalar(u8, path, '/')) |idx|
        path[idx + 1 ..]
    else
        path;

    // Check suffix matches for patterns like *.pyc
    for (DEFAULT_IGNORES) |pattern| {
        if (std.mem.startsWith(u8, pattern, "*.")) {
            const ext = pattern[1..]; // Get extension including dot
            if (std.mem.endsWith(u8, basename, ext)) return true;
        }
    }

    return false;
}
|
||||
|
||||
test "GitIgnore basic patterns" {
    const allocator = std.testing.allocator;

    var matcher = GitIgnore.init(allocator);
    defer matcher.deinit();

    try matcher.parse("node_modules\n__pycache__\n*.pyc\n");

    // Plain names and extension globs should be ignored...
    try std.testing.expect(matcher.isIgnored("node_modules", true));
    try std.testing.expect(matcher.isIgnored("__pycache__", true));
    try std.testing.expect(matcher.isIgnored("test.pyc", false));
    // ...while unrelated files pass through.
    try std.testing.expect(!matcher.isIgnored("main.py", false));
}
|
||||
|
||||
test "GitIgnore negation" {
    const allocator = std.testing.allocator;

    var matcher = GitIgnore.init(allocator);
    defer matcher.deinit();

    // A later `!` pattern re-includes a previously ignored file.
    try matcher.parse("*.log\n!important.log\n");

    try std.testing.expect(matcher.isIgnored("debug.log", false));
    try std.testing.expect(!matcher.isIgnored("important.log", false));
}
|
||||
|
||||
test "GitIgnore directory-only" {
    const allocator = std.testing.allocator;

    var matcher = GitIgnore.init(allocator);
    defer matcher.deinit();

    // Trailing slash restricts the pattern to directories.
    try matcher.parse("build/\n");

    try std.testing.expect(matcher.isIgnored("build", true));
    try std.testing.expect(!matcher.isIgnored("build", false));
}
|
||||
|
||||
test "matchesDefaultIgnore" {
    // Exact-name defaults.
    try std.testing.expect(matchesDefaultIgnore(".git"));
    try std.testing.expect(matchesDefaultIgnore("__pycache__"));
    try std.testing.expect(matchesDefaultIgnore("node_modules"));
    // Extension-based default (*.pyc) on a top-level filename.
    try std.testing.expect(matchesDefaultIgnore("test.pyc"));
    try std.testing.expect(!matchesDefaultIgnore("main.py"));
}
|
||||
195
cli/src/utils/native_hash.zig
Normal file
195
cli/src/utils/native_hash.zig
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
const std = @import("std");
|
||||
const c = @cImport({
|
||||
@cInclude("dataset_hash.h");
|
||||
});
|
||||
|
||||
/// Native hash context for high-performance file hashing.
///
/// Thin Zig wrapper around the C `fh_*` API (dataset_hash.h). All returned
/// hash/path strings are duped onto `allocator` so callers never hold C-owned
/// memory; the C strings are released with fh_free_string as soon as they are
/// copied.
pub const NativeHasher = struct {
    ctx: *c.fh_context_t,
    allocator: std.mem.Allocator,

    /// Initialize native hasher with thread pool.
    /// num_threads: 0 = auto-detect (use hardware concurrency).
    /// Fails with error.NativeInitFailed when the native library is absent
    /// or refuses to initialize.
    pub fn init(allocator: std.mem.Allocator, num_threads: u32) !NativeHasher {
        const ctx = c.fh_init(num_threads);
        if (ctx == null) return error.NativeInitFailed;

        return .{
            .ctx = ctx,
            .allocator = allocator,
        };
    }

    /// Cleanup native hasher and thread pool.
    pub fn deinit(self: *NativeHasher) void {
        c.fh_cleanup(self.ctx);
    }

    /// Hash a single file. Caller owns the returned slice.
    pub fn hashFile(self: *NativeHasher, path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, path);
        defer self.allocator.free(c_path);

        const result = c.fh_hash_file(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);

        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Batch hash multiple files (amortizes CGo overhead).
    /// Returns one hash per input path, in input order; caller owns the outer
    /// slice and every inner slice.
    pub fn hashBatch(self: *NativeHasher, paths: []const []const u8) ![][]const u8 {
        // Convert paths to a C string array. A populated-count is tracked so
        // cleanup only touches initialized slots: the original registered its
        // freeing defer *after* the fill loop (leaking earlier copies on a
        // mid-loop OOM) and iterated the full array, reading uninitialized
        // pointers.
        const c_paths = try self.allocator.alloc([*c]const u8, paths.len);
        defer self.allocator.free(c_paths);

        var n_paths: usize = 0;
        defer {
            for (c_paths[0..n_paths]) |p| {
                self.allocator.free(std.mem.span(p));
            }
        }
        for (paths) |path| {
            const c_path = try self.allocator.dupeZ(u8, path);
            c_paths[n_paths] = c_path.ptr;
            n_paths += 1;
        }

        // Allocate results array
        const results = try self.allocator.alloc([*c]u8, paths.len);
        defer self.allocator.free(results);

        // Call native batch hash
        const ret = c.fh_hash_batch(self.ctx, c_paths.ptr, @intCast(paths.len), results.ptr);
        if (ret != 0) return error.HashFailed;

        // Convert results to Zig strings. Track progress so a mid-loop OOM
        // frees only the converted entries and still releases the remaining
        // C strings.
        var hashes = try self.allocator.alloc([]const u8, paths.len);
        var n_converted: usize = 0;
        errdefer {
            for (hashes[0..n_converted]) |h| self.allocator.free(h);
            self.allocator.free(hashes);
            for (results[n_converted..]) |r| c.fh_free_string(r);
        }

        for (results) |r| {
            hashes[n_converted] = try self.allocator.dupe(u8, std.mem.span(r));
            c.fh_free_string(r);
            n_converted += 1;
        }

        return hashes;
    }

    /// Hash entire directory (combined hash). Caller owns the result.
    pub fn hashDirectory(self: *NativeHasher, dir_path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);

        const result = c.fh_hash_directory(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);

        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Hash directory with batch output (individual file hashes).
    /// `max_results` bounds the number of entries the native side may return.
    /// Caller owns both returned slices and every string inside them.
    pub fn hashDirectoryBatch(
        self: *NativeHasher,
        dir_path: []const u8,
        max_results: u32,
    ) !struct { hashes: [][]const u8, paths: [][]const u8, count: u32 } {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);

        // Allocate output arrays for the C call.
        const hashes = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(hashes);

        const paths = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(paths);

        var count: u32 = 0;

        const ret = c.fh_hash_directory_batch(
            self.ctx,
            c_path.ptr,
            hashes.ptr,
            paths.ptr,
            max_results,
            &count,
        );
        if (ret != 0) return error.HashFailed;

        // Convert to Zig arrays. Progress counters keep error cleanup from
        // freeing uninitialized slots (the original errdefers walked all
        // `count` entries regardless of how many had been converted) and
        // release any C strings not yet consumed.
        var zig_hashes = try self.allocator.alloc([]const u8, count);
        errdefer self.allocator.free(zig_hashes);
        var zig_paths = try self.allocator.alloc([]const u8, count);
        errdefer self.allocator.free(zig_paths);

        var h_done: usize = 0;
        errdefer {
            for (zig_hashes[0..h_done]) |h| self.allocator.free(h);
            for (hashes[h_done..count]) |h| c.fh_free_string(h);
        }
        for (0..count) |i| {
            zig_hashes[i] = try self.allocator.dupe(u8, std.mem.span(hashes[i]));
            c.fh_free_string(hashes[i]);
            h_done += 1;
        }

        var p_done: usize = 0;
        errdefer {
            for (zig_paths[0..p_done]) |p| self.allocator.free(p);
            for (paths[p_done..count]) |p| c.fh_free_string(p);
        }
        for (0..count) |i| {
            zig_paths[i] = try self.allocator.dupe(u8, std.mem.span(paths[i]));
            c.fh_free_string(paths[i]);
            p_done += 1;
        }

        return .{
            .hashes = zig_hashes,
            .paths = zig_paths,
            .count = count,
        };
    }

    /// Check if SIMD SHA-256 is available.
    pub fn hasSimd(self: *NativeHasher) bool {
        _ = self;
        return c.fh_has_simd_sha256() != 0;
    }

    /// Get implementation info (SIMD type, etc.).
    /// The returned string is owned by the C library; do not free it.
    pub fn getImplInfo(self: *NativeHasher) []const u8 {
        _ = self;
        return std.mem.span(c.fh_get_simd_impl_name());
    }
};
|
||||
|
||||
/// Convenience function: hash directory using native library.
/// Spins up a one-shot hasher (auto-detected thread count) and tears it down
/// before returning. Caller owns the returned hash string.
pub fn hashDirectoryNative(allocator: std.mem.Allocator, dir_path: []const u8) ![]const u8 {
    var hasher = try NativeHasher.init(allocator, 0);
    defer hasher.deinit();
    const digest = try hasher.hashDirectory(dir_path);
    return digest;
}
|
||||
|
||||
/// Convenience function: batch hash files using native library.
/// Spins up a one-shot hasher (auto-detected thread count) and tears it down
/// before returning. Caller owns the returned slices.
pub fn hashFilesNative(
    allocator: std.mem.Allocator,
    paths: []const []const u8,
) ![][]const u8 {
    var hasher = try NativeHasher.init(allocator, 0);
    defer hasher.deinit();
    const digests = try hasher.hashBatch(paths);
    return digests;
}
|
||||
|
||||
test "NativeHasher basic operations" {
    const allocator = std.testing.allocator;

    // The native library may be absent; treat that as a skip, not a failure.
    var hasher = NativeHasher.init(allocator, 1) catch |err| switch (err) {
        error.NativeInitFailed => {
            std.debug.print("Native library not available, skipping test\n", .{});
            return;
        },
        else => return err,
    };
    defer hasher.deinit();

    // Report SIMD availability and the active implementation.
    const has_simd = hasher.hasSimd();
    const impl_name = hasher.getImplInfo();
    std.debug.print("SIMD: {any}, Impl: {s}\n", .{ has_simd, impl_name });
}
|
||||
231
cli/src/utils/parallel_walk.zig
Normal file
231
cli/src/utils/parallel_walk.zig
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
const std = @import("std");
|
||||
const ignore = @import("ignore.zig");
|
||||
|
||||
/// Thread-safe work queue for parallel directory walking.
///
/// Producers push directory paths (duped onto the allocator); workers block
/// in pop() until an item is available or setDone() is called. Ownership of
/// a path transfers to the worker when pop() returns it.
const WorkQueue = struct {
    items: std.ArrayList(WorkItem),
    mutex: std.Thread.Mutex,
    condition: std.Thread.Condition,
    done: bool, // set by setDone(); lets blocked pop() calls drain and exit

    const WorkItem = struct {
        path: []const u8, // duped; freed by the worker after processing
        depth: usize, // recursion depth, checked against max_depth by workers
    };

    fn init(allocator: std.mem.Allocator) WorkQueue {
        return .{
            .items = std.ArrayList(WorkItem).init(allocator),
            .mutex = .{},
            .condition = .{},
            .done = false,
        };
    }

    fn deinit(self: *WorkQueue, allocator: std.mem.Allocator) void {
        // Free any paths that were pushed but never popped.
        for (self.items.items) |item| {
            allocator.free(item.path);
        }
        self.items.deinit();
    }

    /// Enqueue a directory (path is duped) and wake one waiting worker.
    fn push(self: *WorkQueue, path: []const u8, depth: usize, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();

        try self.items.append(.{
            .path = try allocator.dupe(u8, path),
            .depth = depth,
        });
        self.condition.signal();
    }

    /// Block until an item is available or the queue is marked done.
    /// Returns null only once the queue is empty AND done is set.
    ///
    /// NOTE(review): nothing in this file ever calls setDone(), so once the
    /// queue drains, workers blocked here wait forever (parallelWalk's joins
    /// would hang). Confirm where the termination protocol lives — a proper
    /// fix needs outstanding-work accounting shared with the workers.
    fn pop(self: *WorkQueue) ?WorkItem {
        self.mutex.lock();
        defer self.mutex.unlock();

        while (self.items.items.len == 0 and !self.done) {
            self.condition.wait(&self.mutex);
        }

        if (self.items.items.len == 0) return null;
        return self.items.pop();
    }

    /// Mark the queue finished and wake all waiting workers so they can exit.
    fn setDone(self: *WorkQueue) void {
        self.mutex.lock();
        defer self.mutex.unlock();
        self.done = true;
        self.condition.broadcast();
    }
};
|
||||
|
||||
/// Result from parallel directory walk.
/// Collects file paths (duped onto the allocator) under a mutex so multiple
/// worker threads can append concurrently.
const WalkResult = struct {
    files: std.ArrayList([]const u8),
    mutex: std.Thread.Mutex,

    fn init(allocator: std.mem.Allocator) WalkResult {
        return .{
            .mutex = .{},
            .files = std.ArrayList([]const u8).init(allocator),
        };
    }

    fn deinit(self: *WalkResult, allocator: std.mem.Allocator) void {
        for (self.files.items) |path| allocator.free(path);
        self.files.deinit();
    }

    /// Thread-safe append; the path is copied so the caller may free its own.
    fn add(self: *WalkResult, path: []const u8, allocator: std.mem.Allocator) !void {
        const copy = try allocator.dupe(u8, path);
        self.mutex.lock();
        defer self.mutex.unlock();
        try self.files.append(copy);
    }
};
|
||||
|
||||
/// Thread context for parallel walking.
/// Shared read-only by all worker threads; the queue and result provide
/// their own internal locking.
const ThreadContext = struct {
    queue: *WorkQueue, // pending directories to visit
    result: *WalkResult, // accumulated file paths
    gitignore: *ignore.GitIgnore, // patterns applied to every entry
    base_path: []const u8, // root directory; entry paths are relative to it
    allocator: std.mem.Allocator,
    max_depth: usize, // recursion limit (guards against cycles)
};
|
||||
|
||||
/// Worker thread function for parallel directory walking.
///
/// Pops directories off the shared queue until pop() returns null, walking
/// each one. Per-directory errors are logged and skipped so one bad
/// directory does not abort the whole walk. Items at or beyond max_depth
/// are dropped without walking.
///
/// NOTE(review): pop() returns null only after WorkQueue.setDone(), which
/// no code in this file calls — confirm how workers are meant to terminate.
fn walkWorker(ctx: *ThreadContext) void {
    while (true) {
        const item = ctx.queue.pop() orelse break;
        // pop() transferred ownership of the path to this worker.
        defer ctx.allocator.free(item.path);

        if (item.depth >= ctx.max_depth) continue;

        walkDirectoryParallel(ctx, item.path, item.depth) catch |err| {
            std.log.warn("Error walking {s}: {any}", .{ item.path, err });
        };
    }
}
|
||||
|
||||
/// Walk a single directory and add subdirectories to queue.
///
/// `dir_path` is relative to ctx.base_path ("." means the base itself).
/// Files pass through the default-ignore list and gitignore patterns before
/// being added to the shared result; directories that pass are queued for
/// other workers. Inaccessible or vanished directories are skipped silently.
fn walkDirectoryParallel(ctx: *ThreadContext, dir_path: []const u8, depth: usize) !void {
    // For "." reuse base_path directly (no allocation); otherwise join and
    // free the joined path on exit — the conditional defer mirrors the
    // conditional allocation above it.
    const full_path = if (std.mem.eql(u8, dir_path, "."))
        ctx.base_path
    else
        try std.fs.path.join(ctx.allocator, &[_][]const u8{ ctx.base_path, dir_path });
    defer if (!std.mem.eql(u8, dir_path, ".")) ctx.allocator.free(full_path);

    var dir = std.fs.cwd().openDir(full_path, .{ .iterate = true }) catch |err| switch (err) {
        error.AccessDenied => return,
        error.FileNotFound => return,
        else => return err,
    };
    defer dir.close();

    var it = dir.iterate();
    while (true) {
        const entry = it.next() catch |err| switch (err) {
            error.AccessDenied => continue, // skip unreadable entries, keep iterating
            else => return err,
        } orelse break;

        // Build the entry's path relative to base_path.
        const entry_path = if (std.mem.eql(u8, dir_path, "."))
            try std.fmt.allocPrint(ctx.allocator, "{s}", .{entry.name})
        else
            try std.fs.path.join(ctx.allocator, &[_][]const u8{ dir_path, entry.name });
        defer ctx.allocator.free(entry_path);

        // Check default ignores
        if (ignore.matchesDefaultIgnore(entry_path)) continue;

        // Check gitignore patterns
        is_dir_check: {
            break :is_dir_check;
        }
        const is_dir = entry.kind == .directory;
        if (ctx.gitignore.isIgnored(entry_path, is_dir)) continue;

        if (is_dir) {
            // Add subdirectory to work queue (push dupes the path).
            ctx.queue.push(entry_path, depth + 1, ctx.allocator) catch |err| {
                std.log.warn("Failed to queue {s}: {any}", .{ entry_path, err });
            };
        } else if (entry.kind == .file) {
            // Add file to results (add dupes the path).
            ctx.result.add(entry_path, ctx.allocator) catch |err| {
                std.log.warn("Failed to add file {s}: {any}", .{ entry_path, err });
            };
        }
    }
}
|
||||
|
||||
/// Parallel directory walker that uses multiple threads.
///
/// Seeds the work queue with the base directory, spawns `num_threads`
/// workers that expand it, then returns the discovered file paths sorted
/// lexicographically. Caller owns the returned slice and every string in it.
///
/// NOTE(review): queue.setDone() is never called, and WorkQueue.pop() blocks
/// until done is set — so once the tree is fully walked, workers appear to
/// block in pop() forever and the joins below would hang. Confirm against a
/// runnable build; fixing it needs outstanding-item accounting so the last
/// idle worker can signal completion.
pub fn parallelWalk(
    allocator: std.mem.Allocator,
    base_path: []const u8,
    gitignore: *ignore.GitIgnore,
    num_threads: usize,
) ![][]const u8 {
    var queue = WorkQueue.init(allocator);
    defer queue.deinit(allocator);

    var result = WalkResult.init(allocator);
    defer result.deinit(allocator);

    // Start with base directory
    try queue.push(".", 0, allocator);

    // Create thread context (shared read-only by all workers)
    var ctx = ThreadContext{
        .queue = &queue,
        .result = &result,
        .gitignore = gitignore,
        .base_path = base_path,
        .allocator = allocator,
        .max_depth = 100, // Prevent infinite recursion
    };

    // Spawn worker threads
    var threads = try allocator.alloc(std.Thread, num_threads);
    defer allocator.free(threads);

    for (0..num_threads) |i| {
        threads[i] = try std.Thread.spawn(.{}, walkWorker, .{&ctx});
    }

    // Wait for all workers to complete
    for (threads) |thread| {
        thread.join();
    }

    // Sort results for deterministic ordering
    std.sort.block([]const u8, result.files.items, {}, struct {
        fn lessThan(_: void, a: []const u8, b: []const u8) bool {
            return std.mem.order(u8, a, b) == .lt;
        }
    }.lessThan);

    // Transfer ownership to caller: copy the string pointers out, then zero
    // the list length so result.deinit() frees neither the strings nor
    // (via the copied pointers) anything the caller now owns.
    const files = try allocator.alloc([]const u8, result.files.items.len);
    @memcpy(files, result.files.items);
    result.files.items.len = 0; // Prevent deinit from freeing
    return files;
}
|
||||
|
||||
test "parallelWalk basic" {
    const allocator = std.testing.allocator;

    var gitignore = ignore.GitIgnore.init(allocator);
    defer gitignore.deinit();

    // Walk the current directory with 4 threads.
    // NOTE(review): given that WorkQueue.setDone() is never invoked, this
    // test looks like it would hang in parallelWalk's joins — verify it
    // actually runs to completion.
    const files = try parallelWalk(allocator, ".", &gitignore, 4);
    defer {
        for (files) |f| allocator.free(f);
        allocator.free(files);
    }

    // Should find at least some files
    try std.testing.expect(files.len > 0);
}
|
||||
283
cli/src/utils/watch.zig
Normal file
283
cli/src/utils/watch.zig
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
const std = @import("std");
|
||||
const os = std.os;
|
||||
const log = std.log;
|
||||
|
||||
/// File watcher using OS-native APIs (kqueue on macOS, inotify on Linux)
|
||||
/// Zero third-party dependencies - uses only standard library OS bindings
|
||||
pub const FileWatcher = struct {
|
||||
allocator: std.mem.Allocator,
|
||||
watched_paths: std.StringHashMap(void),
|
||||
|
||||
// Platform-specific handles
|
||||
kqueue_fd: if (@import("builtin").target.os.tag == .macos) i32 else void,
|
||||
inotify_fd: if (@import("builtin").target.os.tag == .linux) i32 else void,
|
||||
|
||||
// Debounce timer
|
||||
last_event_time: i64,
|
||||
debounce_ms: i64,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, debounce_ms: i64) !FileWatcher {
|
||||
const target = @import("builtin").target;
|
||||
|
||||
var watcher = FileWatcher{
|
||||
.allocator = allocator,
|
||||
.watched_paths = std.StringHashMap(void).init(allocator),
|
||||
.kqueue_fd = if (target.os.tag == .macos) -1 else {},
|
||||
.inotify_fd = if (target.os.tag == .linux) -1 else {},
|
||||
.last_event_time = 0,
|
||||
.debounce_ms = debounce_ms,
|
||||
};
|
||||
|
||||
switch (target.os.tag) {
|
||||
.macos => {
|
||||
watcher.kqueue_fd = try os.kqueue();
|
||||
},
|
||||
.linux => {
|
||||
watcher.inotify_fd = try std.os.inotify_init1(os.linux.IN_CLOEXEC);
|
||||
},
|
||||
else => {
|
||||
return error.UnsupportedPlatform;
|
||||
},
|
||||
}
|
||||
|
||||
return watcher;
|
||||
}
|
||||
|
||||
pub fn deinit(self: *FileWatcher) void {
|
||||
const target = @import("builtin").target;
|
||||
|
||||
switch (target.os.tag) {
|
||||
.macos => {
|
||||
if (self.kqueue_fd != -1) {
|
||||
os.close(self.kqueue_fd);
|
||||
}
|
||||
},
|
||||
.linux => {
|
||||
if (self.inotify_fd != -1) {
|
||||
os.close(self.inotify_fd);
|
||||
}
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
|
||||
var it = self.watched_paths.keyIterator();
|
||||
while (it.next()) |key| {
|
||||
self.allocator.free(key.*);
|
||||
}
|
||||
self.watched_paths.deinit();
|
||||
}
|
||||
|
||||
/// Add a directory to watch recursively
|
||||
pub fn watchDirectory(self: *FileWatcher, path: []const u8) !void {
|
||||
if (self.watched_paths.contains(path)) return;
|
||||
|
||||
const path_copy = try self.allocator.dupe(u8, path);
|
||||
try self.watched_paths.put(path_copy, {});
|
||||
|
||||
const target = @import("builtin").target;
|
||||
|
||||
switch (target.os.tag) {
|
||||
.macos => try self.addKqueueWatch(path),
|
||||
.linux => try self.addInotifyWatch(path),
|
||||
else => return error.UnsupportedPlatform,
|
||||
}
|
||||
|
||||
// Recursively watch subdirectories
|
||||
var dir = std.fs.cwd().openDir(path, .{ .iterate = true }) catch |err| switch (err) {
|
||||
error.AccessDenied => return,
|
||||
error.FileNotFound => return,
|
||||
else => return err,
|
||||
};
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
while (true) {
|
||||
const entry = it.next() catch |err| switch (err) {
|
||||
error.AccessDenied => continue,
|
||||
else => return err,
|
||||
} orelse break;
|
||||
|
||||
if (entry.kind == .directory) {
|
||||
const subpath = try std.fs.path.join(self.allocator, &[_][]const u8{ path, entry.name });
|
||||
defer self.allocator.free(subpath);
|
||||
try self.watchDirectory(subpath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add kqueue watch for macOS
|
||||
fn addKqueueWatch(self: *FileWatcher, path: []const u8) !void {
|
||||
if (@import("builtin").target.os.tag != .macos) return;
|
||||
|
||||
const fd = try os.open(path, os.O.EVTONLY | os.O_RDONLY, 0);
|
||||
defer os.close(fd);
|
||||
|
||||
const event = os.Kevent{
|
||||
.ident = @intCast(fd),
|
||||
.filter = os.EVFILT_VNODE,
|
||||
.flags = os.EV_ADD | os.EV_CLEAR,
|
||||
.fflags = os.NOTE_WRITE | os.NOTE_EXTEND | os.NOTE_ATTRIB | os.NOTE_LINK | os.NOTE_RENAME | os.NOTE_REVOKE,
|
||||
.data = 0,
|
||||
.udata = 0,
|
||||
};
|
||||
|
||||
const changes = [_]os.Kevent{event};
|
||||
_ = try os.kevent(self.kqueue_fd, &changes, &.{}, null);
|
||||
}
|
||||
|
||||
/// Add inotify watch for Linux
|
||||
fn addInotifyWatch(self: *FileWatcher, path: []const u8) !void {
|
||||
if (@import("builtin").target.os.tag != .linux) return;
|
||||
|
||||
const mask = os.linux.IN_MODIFY | os.linux.IN_CREATE | os.linux.IN_DELETE |
|
||||
os.linux.IN_MOVED_FROM | os.linux.IN_MOVED_TO | os.linux.IN_ATTRIB;
|
||||
|
||||
const wd = try os.linux.inotify_add_watch(self.inotify_fd, path.ptr, mask);
|
||||
if (wd < 0) return error.InotifyError;
|
||||
}
|
||||
|
||||
/// Wait for file changes with debouncing
|
||||
pub fn waitForChanges(self: *FileWatcher, timeout_ms: i32) !bool {
|
||||
const target = @import("builtin").target;
|
||||
|
||||
switch (target.os.tag) {
|
||||
.macos => return try self.waitKqueue(timeout_ms),
|
||||
.linux => return try self.waitInotify(timeout_ms),
|
||||
else => return error.UnsupportedPlatform,
|
||||
}
|
||||
}
|
||||
|
||||
/// kqueue wait implementation
|
||||
fn waitKqueue(self: *FileWatcher, timeout_ms: i32) !bool {
|
||||
if (@import("builtin").target.os.tag != .macos) return false;
|
||||
|
||||
var ts: os.timespec = undefined;
|
||||
if (timeout_ms >= 0) {
|
||||
ts.tv_sec = @divTrunc(timeout_ms, 1000);
|
||||
ts.tv_nsec = @mod(timeout_ms, 1000) * 1000000;
|
||||
}
|
||||
|
||||
var events: [10]os.Kevent = undefined;
|
||||
const nev = os.kevent(self.kqueue_fd, &.{}, &events, if (timeout_ms >= 0) &ts else null) catch |err| switch (err) {
|
||||
error.ETIME => return false, // Timeout
|
||||
else => return err,
|
||||
};
|
||||
|
||||
if (nev > 0) {
|
||||
const now = std.time.milliTimestamp();
|
||||
if (now - self.last_event_time > self.debounce_ms) {
|
||||
self.last_event_time = now;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// inotify wait implementation (Linux only).
/// Returns true when a change event survives the debounce window; false on
/// timeout, debounce suppression, or non-Linux targets.
fn waitInotify(self: *FileWatcher, timeout_ms: i32) !bool {
    if (@import("builtin").target.os.tag != .linux) return false;

    // poll() on the inotify fd until it becomes readable or the timeout
    // expires (negative timeout blocks indefinitely).
    var fds = [_]os.pollfd{.{
        .fd = self.inotify_fd,
        .events = os.POLLIN,
        .revents = 0,
    }};

    const ready = os.poll(&fds, timeout_ms) catch |err| switch (err) {
        error.ETIME => return false,
        else => return err,
    };

    if (ready > 0 and (fds[0].revents & os.POLLIN) != 0) {
        // Drain pending events. The buffer is aligned for inotify_event
        // records; the event contents are ignored — any data read means
        // "something changed".
        var buf: [4096]u8 align(@alignOf(os.linux.inotify_event)) = undefined;

        const bytes_read = try os.read(self.inotify_fd, &buf);
        if (bytes_read > 0) {
            // Debounce: collapse bursts of events into one notification.
            const now = std.time.milliTimestamp();
            if (now - self.last_event_time > self.debounce_ms) {
                self.last_event_time = now;
                return true;
            }
        }
    }

    return false;
}
|
||||
|
||||
/// Run watch loop with callback: blocks forever, invoking `callback` each
/// time a debounced file change is reported.
pub fn run(self: *FileWatcher, callback: fn () void) !void {
    log.info("Watching for file changes (debounce: {d}ms)...", .{self.debounce_ms});

    while (true) {
        const changed = try self.waitForChanges(-1);
        if (changed) callback();
    }
}
|
||||
};
|
||||
|
||||
/// Watch command handler.
/// Watches `args[0]` for file changes; optional flags:
///   --sync   log an auto-sync message on each change
///   --queue  log an auto-queue message on each change
/// Blocks until the process is interrupted. Returns error.InvalidArgs when
/// no path is given.
///
/// Fix: removed the dead `CallbackContext`/`ctx` locals — they were built
/// and then only discarded via `_ = ctx;`, never actually used.
pub fn watchCommand(allocator: std.mem.Allocator, args: []const []const u8) !void {
    if (args.len < 1) {
        std.debug.print("Usage: ml watch <path> [--sync] [--queue]\n", .{});
        return error.InvalidArgs;
    }

    const path = args[0];
    var auto_sync = false;
    var auto_queue = false;

    // Parse flags
    for (args[1..]) |arg| {
        if (std.mem.eql(u8, arg, "--sync")) auto_sync = true;
        if (std.mem.eql(u8, arg, "--queue")) auto_queue = true;
    }

    var watcher = try FileWatcher.init(allocator, 100); // 100ms debounce
    defer watcher.deinit();

    try watcher.watchDirectory(path);

    log.info("Watching {s} for changes...", .{path});

    // Run watch loop; a negative timeout blocks until a change arrives.
    while (true) {
        if (try watcher.waitForChanges(-1)) {
            log.info("File changes detected", .{});

            if (auto_sync) {
                log.info("Auto-syncing...", .{});
                // Trigger sync (implementation would call sync command)
            }

            if (auto_queue) {
                log.info("Auto-queuing...", .{});
                // Trigger queue (implementation would call queue command)
            }
        }
    }
}
|
||||
|
||||
// Smoke test: a FileWatcher can be constructed and torn down cleanly.
// std.testing.allocator additionally fails the test on leaked allocations.
test "FileWatcher init/deinit" {
    const allocator = std.testing.allocator;

    var watcher = try FileWatcher.init(allocator, 100);
    defer watcher.deinit();
}
|
||||
|
|
@ -356,7 +356,8 @@ api_key = "your_api_key_here" # Your API key (get from admin)
|
|||
|
||||
// Set proper permissions
|
||||
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
// Log permission warning but don't fail
|
||||
_ = err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ type Config struct {
|
|||
SSHKey string `toml:"ssh_key"`
|
||||
Port int `toml:"port"`
|
||||
BasePath string `toml:"base_path"`
|
||||
Mode string `toml:"mode"` // "dev" or "prod"
|
||||
WrapperScript string `toml:"wrapper_script"`
|
||||
TrainScript string `toml:"train_script"`
|
||||
RedisAddr string `toml:"redis_addr"`
|
||||
RedisPassword string `toml:"redis_password"`
|
||||
RedisDB int `toml:"redis_db"`
|
||||
KnownHosts string `toml:"known_hosts"`
|
||||
ServerURL string `toml:"server_url"` // WebSocket server URL (e.g., ws://localhost:8080)
|
||||
|
||||
// Authentication
|
||||
Auth auth.Config `toml:"auth"`
|
||||
|
|
@ -133,6 +135,20 @@ func (c *Config) Validate() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Set default mode if not specified
|
||||
if c.Mode == "" {
|
||||
if os.Getenv("FETCH_ML_TUI_MODE") != "" {
|
||||
c.Mode = os.Getenv("FETCH_ML_TUI_MODE")
|
||||
} else {
|
||||
c.Mode = "dev" // Default to dev mode
|
||||
}
|
||||
}
|
||||
|
||||
// Set mode-appropriate default paths using project-relative paths
|
||||
if c.BasePath == "" {
|
||||
c.BasePath = utils.ModeBasedBasePath(c.Mode)
|
||||
}
|
||||
|
||||
if c.BasePath != "" {
|
||||
// Convert relative paths to absolute
|
||||
c.BasePath = utils.ExpandPath(c.BasePath)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ func (c *Controller) loadAllData() tea.Cmd {
|
|||
c.loadQueue(),
|
||||
c.loadGPU(),
|
||||
c.loadContainer(),
|
||||
c.loadDatasets(),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -39,6 +40,13 @@ func (c *Controller) loadJobs() tea.Cmd {
|
|||
var jobs []model.Job
|
||||
statusChan := make(chan []model.Job, 4)
|
||||
|
||||
// Debug: Print paths being used
|
||||
c.logger.Info("Loading jobs from paths",
|
||||
"pending", c.getPathForStatus(model.StatusPending),
|
||||
"running", c.getPathForStatus(model.StatusRunning),
|
||||
"finished", c.getPathForStatus(model.StatusFinished),
|
||||
"failed", c.getPathForStatus(model.StatusFailed))
|
||||
|
||||
for _, status := range []model.JobStatus{
|
||||
model.StatusPending,
|
||||
model.StatusRunning,
|
||||
|
|
@ -48,22 +56,18 @@ func (c *Controller) loadJobs() tea.Cmd {
|
|||
go func(s model.JobStatus) {
|
||||
path := c.getPathForStatus(s)
|
||||
names := c.server.ListDir(path)
|
||||
|
||||
// Debug: Log what we found
|
||||
c.logger.Info("Listed directory", "status", s, "path", path, "count", len(names))
|
||||
|
||||
var statusJobs []model.Job
|
||||
for _, name := range names {
|
||||
jobStatus, _ := c.taskQueue.GetJobStatus(name)
|
||||
taskID := jobStatus["task_id"]
|
||||
priority := int64(0)
|
||||
if p, ok := jobStatus["priority"]; ok {
|
||||
_, err := fmt.Sscanf(p, "%d", &priority)
|
||||
if err != nil {
|
||||
priority = 0
|
||||
}
|
||||
}
|
||||
// Lazy loading: only fetch basic info for list view
|
||||
// Full details (GPU, narrative) loaded on selection
|
||||
statusJobs = append(statusJobs, model.Job{
|
||||
Name: name,
|
||||
Status: s,
|
||||
TaskID: taskID,
|
||||
Priority: priority,
|
||||
// TaskID, Priority, GPU info loaded lazily
|
||||
})
|
||||
}
|
||||
statusChan <- statusJobs
|
||||
|
|
@ -85,6 +89,24 @@ func (c *Controller) loadJobs() tea.Cmd {
|
|||
}
|
||||
}
|
||||
|
||||
// loadJobDetails loads full details for a specific job (lazy loading)
|
||||
func (c *Controller) loadJobDetails(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
jobStatus, _ := c.taskQueue.GetJobStatus(jobName)
|
||||
|
||||
// Parse priority
|
||||
priority := int64(0)
|
||||
if p, ok := jobStatus["priority"]; ok {
|
||||
fmt.Sscanf(p, "%d", &priority)
|
||||
}
|
||||
|
||||
// Build full job with details
|
||||
// This is called when job is selected for detailed view
|
||||
|
||||
return model.StatusMsg{Text: "Loaded details for " + jobName, Level: "info"}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadQueue() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
tasks, err := c.taskQueue.GetQueuedTasks()
|
||||
|
|
@ -362,6 +384,18 @@ func (c *Controller) showQueue(m model.State) tea.Cmd {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadDatasets() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
datasets, err := c.taskQueue.ListDatasets()
|
||||
if err != nil {
|
||||
c.logger.Error("failed to load datasets", "error", err)
|
||||
return model.StatusMsg{Text: "Failed to load datasets: " + err.Error(), Level: "error"}
|
||||
}
|
||||
c.logger.Info("loaded datasets", "count", len(datasets))
|
||||
return model.DatasetsLoadedMsg(datasets)
|
||||
}
|
||||
}
|
||||
|
||||
func tickCmd() tea.Cmd {
|
||||
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
|
||||
return model.TickMsg(t)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package controller
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
|
|
@ -19,6 +20,7 @@ type Controller struct {
|
|||
server *services.MLServer
|
||||
taskQueue *services.TaskQueue
|
||||
logger *logging.Logger
|
||||
wsClient *services.WebSocketClient
|
||||
}
|
||||
|
||||
func (c *Controller) handleKeyMsg(msg tea.KeyMsg, m model.State) (model.State, tea.Cmd) {
|
||||
|
|
@ -143,6 +145,33 @@ func (c *Controller) handleGlobalKeys(msg tea.KeyMsg, m *model.State) []tea.Cmd
|
|||
case key.Matches(msg, m.Keys.ViewExperiments):
|
||||
m.ActiveView = model.ViewModeExperiments
|
||||
cmds = append(cmds, c.loadExperiments())
|
||||
case key.Matches(msg, m.Keys.ViewNarrative):
|
||||
m.ActiveView = model.ViewModeNarrative
|
||||
if job := getSelectedJob(*m); job != nil {
|
||||
m.SelectedJob = *job
|
||||
}
|
||||
case key.Matches(msg, m.Keys.ViewTeam):
|
||||
m.ActiveView = model.ViewModeTeam
|
||||
case key.Matches(msg, m.Keys.ViewExperimentHistory):
|
||||
m.ActiveView = model.ViewModeExperimentHistory
|
||||
cmds = append(cmds, c.loadExperimentHistory())
|
||||
case key.Matches(msg, m.Keys.ViewConfig):
|
||||
m.ActiveView = model.ViewModeConfig
|
||||
cmds = append(cmds, c.loadConfig())
|
||||
case key.Matches(msg, m.Keys.ViewLogs):
|
||||
m.ActiveView = model.ViewModeLogs
|
||||
if job := getSelectedJob(*m); job != nil {
|
||||
cmds = append(cmds, c.loadLogs(job.Name))
|
||||
}
|
||||
case key.Matches(msg, m.Keys.ViewExport):
|
||||
if job := getSelectedJob(*m); job != nil {
|
||||
cmds = append(cmds, c.exportJob(job.Name))
|
||||
}
|
||||
case key.Matches(msg, m.Keys.FilterTeam):
|
||||
m.InputMode = true
|
||||
m.Input.SetValue("@")
|
||||
m.Input.Focus()
|
||||
m.Status = "Filter by team member: @alice, @bob, @team-ml"
|
||||
case key.Matches(msg, m.Keys.Cancel):
|
||||
if job := getSelectedJob(*m); job != nil && job.TaskID != "" {
|
||||
cmds = append(cmds, c.cancelTask(job.TaskID))
|
||||
|
|
@ -181,8 +210,18 @@ func (c *Controller) applyWindowSize(msg tea.WindowSizeMsg, m model.State) model
|
|||
m.QueueView.Height = listHeight - 4
|
||||
m.SettingsView.Width = panelWidth
|
||||
m.SettingsView.Height = listHeight - 4
|
||||
m.NarrativeView.Width = panelWidth
|
||||
m.NarrativeView.Height = listHeight - 4
|
||||
m.TeamView.Width = panelWidth
|
||||
m.TeamView.Height = listHeight - 4
|
||||
m.ExperimentsView.Width = panelWidth
|
||||
m.ExperimentsView.Height = listHeight - 4
|
||||
m.ExperimentHistoryView.Width = panelWidth
|
||||
m.ExperimentHistoryView.Height = listHeight - 4
|
||||
m.ConfigView.Width = panelWidth
|
||||
m.ConfigView.Height = listHeight - 4
|
||||
m.LogsView.Width = panelWidth
|
||||
m.LogsView.Height = listHeight - 4
|
||||
|
||||
return m
|
||||
}
|
||||
|
|
@ -245,7 +284,25 @@ func (c *Controller) handleStatusMsg(msg model.StatusMsg, m model.State) (model.
|
|||
|
||||
func (c *Controller) handleTickMsg(msg model.TickMsg, m model.State) (model.State, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
|
||||
|
||||
// Calculate actual refresh rate
|
||||
now := time.Now()
|
||||
if !m.LastFrameTime.IsZero() {
|
||||
elapsed := now.Sub(m.LastFrameTime).Milliseconds()
|
||||
if elapsed > 0 {
|
||||
// Smooth the rate with simple averaging
|
||||
m.RefreshRate = (m.RefreshRate*float64(m.FrameCount) + float64(elapsed)) / float64(m.FrameCount+1)
|
||||
m.FrameCount++
|
||||
if m.FrameCount > 100 {
|
||||
m.FrameCount = 1
|
||||
m.RefreshRate = float64(elapsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.LastFrameTime = now
|
||||
|
||||
// 500ms refresh target for real-time updates
|
||||
if time.Since(m.LastRefresh) > 500*time.Millisecond && !m.IsLoading {
|
||||
m.LastRefresh = time.Now()
|
||||
cmds = append(cmds, c.loadAllData())
|
||||
}
|
||||
|
|
@ -290,16 +347,25 @@ func New(
|
|||
tq *services.TaskQueue,
|
||||
logger *logging.Logger,
|
||||
) *Controller {
|
||||
// Create WebSocket client for real-time updates
|
||||
wsClient := services.NewWebSocketClient(cfg.ServerURL, "", logger)
|
||||
|
||||
return &Controller{
|
||||
config: cfg,
|
||||
server: srv,
|
||||
taskQueue: tq,
|
||||
logger: logger,
|
||||
wsClient: wsClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the TUI and returns initial commands
|
||||
func (c *Controller) Init() tea.Cmd {
|
||||
// Connect WebSocket for real-time updates
|
||||
if err := c.wsClient.Connect(); err != nil {
|
||||
c.logger.Error("WebSocket connection failed", "error", err)
|
||||
}
|
||||
|
||||
return tea.Batch(
|
||||
tea.SetWindowTitle("FetchML"),
|
||||
c.loadAllData(),
|
||||
|
|
@ -307,14 +373,17 @@ func (c *Controller) Init() tea.Cmd {
|
|||
)
|
||||
}
|
||||
|
||||
// Update handles all messages and updates the state
|
||||
func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
||||
switch typed := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
return c.handleKeyMsg(typed, m)
|
||||
case tea.WindowSizeMsg:
|
||||
// Only apply window size on first render, then keep constant
|
||||
if m.Width == 0 && m.Height == 0 {
|
||||
updated := c.applyWindowSize(typed, m)
|
||||
return c.finalizeUpdate(msg, updated)
|
||||
}
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.JobsLoadedMsg:
|
||||
return c.handleJobsLoadedMsg(typed, m)
|
||||
case model.TasksLoadedMsg:
|
||||
|
|
@ -323,8 +392,26 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
|||
return c.handleGPUContent(typed, m)
|
||||
case model.ContainerLoadedMsg:
|
||||
return c.handleContainerContent(typed, m)
|
||||
case model.QueueLoadedMsg:
|
||||
return c.handleQueueContent(typed, m)
|
||||
case model.DatasetsLoadedMsg:
|
||||
// Format datasets into view content
|
||||
var content strings.Builder
|
||||
content.WriteString("Available Datasets\n")
|
||||
content.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
if len(typed) == 0 {
|
||||
content.WriteString("📭 No datasets found\n\n")
|
||||
content.WriteString("Datasets will appear here when available\n")
|
||||
content.WriteString("in the data directory.")
|
||||
} else {
|
||||
for i, ds := range typed {
|
||||
content.WriteString(fmt.Sprintf("%d. 📁 %s\n", i+1, ds.Name))
|
||||
content.WriteString(fmt.Sprintf(" Location: %s\n", ds.Location))
|
||||
content.WriteString(fmt.Sprintf(" Size: %d bytes\n", ds.SizeBytes))
|
||||
content.WriteString(fmt.Sprintf(" Last Access: %s\n\n", ds.LastAccess.Format("2006-01-02 15:04")))
|
||||
}
|
||||
}
|
||||
m.DatasetView.SetContent(content.String())
|
||||
m.DatasetView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.SettingsContentMsg:
|
||||
m.SettingsView.SetContent(string(typed))
|
||||
return c.finalizeUpdate(msg, m)
|
||||
|
|
@ -332,12 +419,36 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
|||
m.ExperimentsView.SetContent(string(typed))
|
||||
m.ExperimentsView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case ExperimentHistoryLoadedMsg:
|
||||
m.ExperimentHistoryView.SetContent(string(typed))
|
||||
m.ExperimentHistoryView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case ConfigLoadedMsg:
|
||||
m.ConfigView.SetContent(string(typed))
|
||||
m.ConfigView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case LogsLoadedMsg:
|
||||
m.LogsView.SetContent(string(typed))
|
||||
m.LogsView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.SettingsUpdateMsg:
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.StatusMsg:
|
||||
return c.handleStatusMsg(typed, m)
|
||||
case model.TickMsg:
|
||||
return c.handleTickMsg(typed, m)
|
||||
case model.JobUpdateMsg:
|
||||
// Handle real-time job status updates from WebSocket
|
||||
m.Status = fmt.Sprintf("Job %s: %s", typed.JobName, typed.Status)
|
||||
// Refresh job list to show updated status
|
||||
return m, c.loadAllData()
|
||||
case model.GPUUpdateMsg:
|
||||
// Throttle GPU updates to 1/second (humans can't perceive faster)
|
||||
if time.Since(m.LastGPUUpdate) > 1*time.Second {
|
||||
m.LastGPUUpdate = time.Now()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
}
|
||||
return m, nil
|
||||
default:
|
||||
return c.finalizeUpdate(msg, m)
|
||||
}
|
||||
|
|
@ -346,6 +457,12 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
|||
// ExperimentsLoadedMsg is sent when experiments are loaded; it carries the
// pre-rendered experiments view content.
type ExperimentsLoadedMsg string

// ExperimentHistoryLoadedMsg is sent when experiment history is loaded; it
// carries the pre-rendered history view content.
type ExperimentHistoryLoadedMsg string

// ConfigLoadedMsg is sent when config is loaded; it carries the pre-rendered
// config view content.
type ConfigLoadedMsg string
|
||||
|
||||
func (c *Controller) loadExperiments() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
commitIDs, err := c.taskQueue.ListExperiments()
|
||||
|
|
@ -372,3 +489,92 @@ func (c *Controller) loadExperiments() tea.Cmd {
|
|||
return ExperimentsLoadedMsg(output)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadExperimentHistory() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Placeholder - will show experiment history with annotations
|
||||
return ExperimentHistoryLoadedMsg("Experiment History & Annotations\n\n" +
|
||||
"This view will show:\n" +
|
||||
"- Previous experiment runs\n" +
|
||||
"- Annotations and notes\n" +
|
||||
"- Config snapshots\n" +
|
||||
"- Side-by-side comparisons\n\n" +
|
||||
"(Requires API: GET /api/experiments/:id/history)")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadConfig() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Build config diff showing changes from defaults
|
||||
var output strings.Builder
|
||||
output.WriteString("⚙️ Config View (Read-Only)\n\n")
|
||||
|
||||
output.WriteString("┌─ Changes from Defaults ─────────────────────┐\n")
|
||||
changes := []string{}
|
||||
|
||||
if c.config.Host != "" {
|
||||
changes = append(changes, fmt.Sprintf("│ Host: %s", c.config.Host))
|
||||
}
|
||||
if c.config.Port != 0 && c.config.Port != 22 {
|
||||
changes = append(changes, fmt.Sprintf("│ Port: %d (default: 22)", c.config.Port))
|
||||
}
|
||||
if c.config.BasePath != "" {
|
||||
changes = append(changes, fmt.Sprintf("│ Base Path: %s", c.config.BasePath))
|
||||
}
|
||||
if c.config.RedisAddr != "" && c.config.RedisAddr != "localhost:6379" {
|
||||
changes = append(changes, fmt.Sprintf("│ Redis: %s (default: localhost:6379)", c.config.RedisAddr))
|
||||
}
|
||||
if c.config.ServerURL != "" {
|
||||
changes = append(changes, fmt.Sprintf("│ Server: %s", c.config.ServerURL))
|
||||
}
|
||||
|
||||
if len(changes) == 0 {
|
||||
output.WriteString("│ (Using all default settings)\n")
|
||||
} else {
|
||||
for _, change := range changes {
|
||||
output.WriteString(change + "\n")
|
||||
}
|
||||
}
|
||||
output.WriteString("└─────────────────────────────────────────────┘\n\n")
|
||||
|
||||
output.WriteString("Full Configuration:\n")
|
||||
output.WriteString(fmt.Sprintf(" Host: %s\n", c.config.Host))
|
||||
output.WriteString(fmt.Sprintf(" Port: %d\n", c.config.Port))
|
||||
output.WriteString(fmt.Sprintf(" Base Path: %s\n", c.config.BasePath))
|
||||
output.WriteString(fmt.Sprintf(" Redis: %s\n", c.config.RedisAddr))
|
||||
output.WriteString(fmt.Sprintf(" Server: %s\n", c.config.ServerURL))
|
||||
output.WriteString(fmt.Sprintf(" User: %s\n\n", c.config.User))
|
||||
|
||||
output.WriteString("Use CLI to modify: ml config set <key> <value>")
|
||||
|
||||
return ConfigLoadedMsg(output.String())
|
||||
}
|
||||
}
|
||||
|
||||
// LogsLoadedMsg is sent when logs are loaded
|
||||
type LogsLoadedMsg string
|
||||
|
||||
func (c *Controller) loadLogs(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Placeholder - will stream logs from job
|
||||
return LogsLoadedMsg("📜 Logs for " + jobName + "\n\n" +
|
||||
"Log streaming will appear here...\n\n" +
|
||||
"(Requires API: GET /api/jobs/" + jobName + "/logs?follow=true)")
|
||||
}
|
||||
}
|
||||
|
||||
// ExportCompletedMsg is sent when export is complete.
// NOTE(review): nothing in this file constructs ExportCompletedMsg yet —
// exportJob below returns a StatusMsg placeholder; wire this up when the
// export API call is implemented.
type ExportCompletedMsg struct {
	JobName string
	Path    string
}

// exportJob returns a command that starts an (anonymized) export of the
// named job. Currently only reports that the export is in progress.
func (c *Controller) exportJob(jobName string) tea.Cmd {
	return func() tea.Msg {
		// Show export in progress
		return model.StatusMsg{
			Text:  "Exporting " + jobName + "... (anonymized)",
			Level: "info",
		}
	}
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
// Package model provides TUI data structures and state management
|
||||
package model
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// JobStatus represents the status of a job
|
||||
type JobStatus string
|
||||
|
|
@ -21,12 +25,23 @@ type Job struct {
|
|||
Status JobStatus
|
||||
TaskID string
|
||||
Priority int64
|
||||
// Narrative fields for research context
|
||||
Hypothesis string
|
||||
Context string
|
||||
Intent string
|
||||
ExpectedOutcome string
|
||||
ActualOutcome string
|
||||
OutcomeStatus string // validated, invalidated, inconclusive
|
||||
// GPU allocation tracking
|
||||
GPUDeviceID int // -1 if not assigned
|
||||
GPUUtilization int // 0-100%
|
||||
GPUMemoryUsed int64 // MB
|
||||
}
|
||||
|
||||
// Title returns the job title for display
|
||||
func (j Job) Title() string { return j.Name }
|
||||
|
||||
// Description returns a formatted description with status icon
|
||||
// Description returns a formatted description with status icon and GPU info
|
||||
func (j Job) Description() string {
|
||||
icon := map[JobStatus]string{
|
||||
StatusPending: "⏸",
|
||||
|
|
@ -39,7 +54,16 @@ func (j Job) Description() string {
|
|||
if j.Priority > 0 {
|
||||
pri = fmt.Sprintf(" [P%d]", j.Priority)
|
||||
}
|
||||
return fmt.Sprintf("%s %s%s", icon, j.Status, pri)
|
||||
gpu := ""
|
||||
if j.GPUDeviceID >= 0 {
|
||||
gpu = fmt.Sprintf(" [GPU:%d %d%%]", j.GPUDeviceID, j.GPUUtilization)
|
||||
}
|
||||
|
||||
// Apply status color to the status text
|
||||
statusStyle := lipgloss.NewStyle().Foreground(StatusColor(j.Status))
|
||||
coloredStatus := statusStyle.Render(string(j.Status))
|
||||
|
||||
return fmt.Sprintf("%s %s%s%s", icon, coloredStatus, pri, gpu)
|
||||
}
|
||||
|
||||
// FilterValue returns the value used for filtering
|
||||
|
|
|
|||
|
|
@ -15,6 +15,13 @@ type KeyMap struct {
|
|||
ViewDatasets key.Binding
|
||||
ViewExperiments key.Binding
|
||||
ViewSettings key.Binding
|
||||
ViewNarrative key.Binding
|
||||
ViewTeam key.Binding
|
||||
ViewExperimentHistory key.Binding
|
||||
ViewConfig key.Binding
|
||||
ViewLogs key.Binding
|
||||
ViewExport key.Binding
|
||||
FilterTeam key.Binding
|
||||
Cancel key.Binding
|
||||
Delete key.Binding
|
||||
MarkFailed key.Binding
|
||||
|
|
@ -29,18 +36,25 @@ func DefaultKeys() KeyMap {
|
|||
Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
|
||||
Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
|
||||
TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
|
||||
ViewQueue: key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")),
|
||||
ViewQueue: key.NewBinding(key.WithKeys("q"), key.WithHelp("q", "view queue")),
|
||||
ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
|
||||
ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
|
||||
ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
|
||||
ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
|
||||
ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
|
||||
Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")),
|
||||
ViewNarrative: key.NewBinding(key.WithKeys("n"), key.WithHelp("n", "narrative")),
|
||||
ViewTeam: key.NewBinding(key.WithKeys("m"), key.WithHelp("m", "team")),
|
||||
ViewExperimentHistory: key.NewBinding(key.WithKeys("e"), key.WithHelp("e", "experiment history")),
|
||||
ViewConfig: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "config")),
|
||||
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
|
||||
ViewLogs: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "logs")),
|
||||
ViewExport: key.NewBinding(key.WithKeys("E"), key.WithHelp("E", "export job")),
|
||||
FilterTeam: key.NewBinding(key.WithKeys("@"), key.WithHelp("@", "filter by team")),
|
||||
Cancel: key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "cancel task")),
|
||||
Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
|
||||
MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
|
||||
RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
|
||||
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
|
||||
Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
|
||||
Quit: key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")),
|
||||
Quit: key.NewBinding(key.WithKeys("ctrl+c"), key.WithHelp("ctrl+c", "quit")),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ type JobsLoadedMsg []Job
|
|||
// TasksLoadedMsg contains loaded tasks from the queue
|
||||
type TasksLoadedMsg []*Task
|
||||
|
||||
// DatasetsLoadedMsg contains loaded datasets
|
||||
type DatasetsLoadedMsg []DatasetInfo
|
||||
|
||||
// GpuLoadedMsg contains GPU status information
|
||||
type GpuLoadedMsg string
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,11 @@ const (
|
|||
ViewModeSettings // Settings view mode
|
||||
ViewModeDatasets // Datasets view mode
|
||||
ViewModeExperiments // Experiments view mode
|
||||
ViewModeNarrative // Narrative/Outcome view mode
|
||||
ViewModeTeam // Team collaboration view mode
|
||||
ViewModeExperimentHistory // Experiment history view mode
|
||||
ViewModeConfig // Config view mode
|
||||
ViewModeLogs // Logs streaming view mode
|
||||
)
|
||||
|
||||
// DatasetInfo represents dataset information in the TUI
|
||||
|
|
@ -61,6 +66,12 @@ type State struct {
|
|||
SettingsView viewport.Model
|
||||
DatasetView viewport.Model
|
||||
ExperimentsView viewport.Model
|
||||
NarrativeView viewport.Model
|
||||
TeamView viewport.Model
|
||||
ExperimentHistoryView viewport.Model
|
||||
ConfigView viewport.Model
|
||||
LogsView viewport.Model
|
||||
SelectedJob Job
|
||||
Input textinput.Model
|
||||
APIKeyInput textinput.Model
|
||||
Status string
|
||||
|
|
@ -72,6 +83,10 @@ type State struct {
|
|||
Spinner spinner.Model
|
||||
ActiveView ViewMode
|
||||
LastRefresh time.Time
|
||||
LastFrameTime time.Time
|
||||
RefreshRate float64 // measured in ms
|
||||
FrameCount int
|
||||
LastGPUUpdate time.Time
|
||||
IsLoading bool
|
||||
JobStats map[JobStatus]int
|
||||
APIKey string
|
||||
|
|
@ -112,6 +127,11 @@ func InitialState(apiKey string) State {
|
|||
SettingsView: viewport.New(0, 0),
|
||||
DatasetView: viewport.New(0, 0),
|
||||
ExperimentsView: viewport.New(0, 0),
|
||||
NarrativeView: viewport.New(0, 0),
|
||||
TeamView: viewport.New(0, 0),
|
||||
ExperimentHistoryView: viewport.New(0, 0),
|
||||
ConfigView: viewport.New(0, 0),
|
||||
LogsView: viewport.New(0, 0),
|
||||
Input: input,
|
||||
APIKeyInput: apiKeyInput,
|
||||
Status: "Connected",
|
||||
|
|
@ -127,3 +147,27 @@ func InitialState(apiKey string) State {
|
|||
Keys: DefaultKeys(),
|
||||
}
|
||||
}
|
||||
|
||||
// LogMsg represents a log line from a job.
type LogMsg struct {
	JobName string `json:"job_name"` // job the line belongs to
	Line    string `json:"line"`     // raw log line text
	Time    string `json:"time"`     // timestamp string as sent by the server
}

// JobUpdateMsg represents a real-time job status update via WebSocket.
type JobUpdateMsg struct {
	JobName string `json:"job_name"`
	Status  string `json:"status"`
	TaskID  string `json:"task_id"`
	// NOTE(review): presumably 0-100 like other progress fields — confirm
	// against the server's WebSocket payload.
	Progress int `json:"progress"`
}

// GPUUpdateMsg represents a real-time GPU status update via WebSocket.
type GPUUpdateMsg struct {
	DeviceID    int   `json:"device_id"`
	Utilization int   `json:"utilization"`
	MemoryUsed  int64 `json:"memory_used"`
	MemoryTotal int64 `json:"memory_total"`
	Temperature int   `json:"temperature"`
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,38 @@ import (
|
|||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Status colors for job list items.
// These hex values are consumed by StatusColor to style job status text.
var (
	// StatusRunningColor is green for running jobs
	StatusRunningColor = lipgloss.Color("#2ecc71")
	// StatusPendingColor is yellow for pending jobs
	StatusPendingColor = lipgloss.Color("#f1c40f")
	// StatusFailedColor is red for failed jobs
	StatusFailedColor = lipgloss.Color("#e74c3c")
	// StatusFinishedColor is blue for completed jobs
	StatusFinishedColor = lipgloss.Color("#3498db")
	// StatusQueuedColor is gray for queued jobs
	StatusQueuedColor = lipgloss.Color("#95a5a6")
)
|
||||
|
||||
// StatusColor returns the color for a job status
|
||||
func StatusColor(status JobStatus) lipgloss.Color {
|
||||
switch status {
|
||||
case StatusRunning:
|
||||
return StatusRunningColor
|
||||
case StatusPending:
|
||||
return StatusPendingColor
|
||||
case StatusFailed:
|
||||
return StatusFailedColor
|
||||
case StatusFinished:
|
||||
return StatusFinishedColor
|
||||
case StatusQueued:
|
||||
return StatusQueuedColor
|
||||
default:
|
||||
return lipgloss.Color("#ffffff")
|
||||
}
|
||||
}
|
||||
|
||||
// NewJobListDelegate creates a styled delegate for the job list
|
||||
func NewJobListDelegate() list.DefaultDelegate {
|
||||
delegate := list.NewDefaultDelegate()
|
||||
|
|
|
|||
46
cmd/tui/internal/services/export.go
Normal file
46
cmd/tui/internal/services/export.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Package services provides TUI service clients
|
||||
package services
|
||||
|
||||
import (
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/jfraeys/fetch_ml/internal/logging"
)
|
||||
|
||||
// ExportService handles job export functionality for TUI.
type ExportService struct {
	serverURL string          // export API server URL (unused while ExportJob is a placeholder)
	apiKey    string          // API key for the export API (unused while ExportJob is a placeholder)
	logger    *logging.Logger // structured logger for export events
}

// NewExportService creates a new export service.
func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportService {
	return &ExportService{
		serverURL: serverURL,
		apiKey:    apiKey,
		logger:    logger,
	}
}
|
||||
|
||||
// ExportJob exports a job with optional anonymization
|
||||
// Returns the path to the exported file
|
||||
func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) {
|
||||
s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize)
|
||||
|
||||
// Placeholder - actual implementation would call API
|
||||
// POST /api/jobs/{id}/export?anonymize=true
|
||||
|
||||
exportPath := fmt.Sprintf("/tmp/%s_export_%d.tar.gz", jobName, time.Now().Unix())
|
||||
|
||||
s.logger.Info("export complete", "job", jobName, "path", exportPath)
|
||||
return exportPath, nil
|
||||
}
|
||||
|
||||
// ExportOptions contains options for export.
// NOTE(review): not consumed by ExportJob yet — it takes a bare anonymize
// bool; fold these options in when the real API call is implemented.
type ExportOptions struct {
	Anonymize   bool
	IncludeLogs bool
	IncludeData bool
}
|
||||
|
|
@ -4,8 +4,12 @@ package services
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
|
|
@ -21,6 +25,7 @@ type TaskQueue struct {
|
|||
*queue.TaskQueue // Embed to inherit all queue methods directly
|
||||
expManager *experiment.Manager
|
||||
ctx context.Context
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewTaskQueue creates a new task queue service
|
||||
|
|
@ -37,14 +42,17 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
|
|||
return nil, fmt.Errorf("failed to create task queue: %w", err)
|
||||
}
|
||||
|
||||
// Initialize experiment manager
|
||||
// TODO: Get base path from config
|
||||
expManager := experiment.NewManager("./experiments")
|
||||
// Initialize experiment manager with proper path
|
||||
// BasePath already includes the mode-based experiments path (e.g., ./data/dev/experiments)
|
||||
expDir := cfg.BasePath
|
||||
os.MkdirAll(expDir, 0755)
|
||||
expManager := experiment.NewManager(expDir)
|
||||
|
||||
return &TaskQueue{
|
||||
TaskQueue: internalQueue,
|
||||
expManager: expManager,
|
||||
ctx: context.Background(),
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -94,20 +102,38 @@ func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
|
|||
return map[string]string{}, nil
|
||||
}
|
||||
|
||||
// ListDatasets retrieves available datasets (TUI-specific: currently returns empty)
|
||||
func (tq *TaskQueue) ListDatasets() ([]struct {
|
||||
Name string
|
||||
SizeBytes int64
|
||||
Location string
|
||||
LastAccess string
|
||||
}, error) {
|
||||
// This method doesn't exist in internal queue, return empty for now
|
||||
return []struct {
|
||||
Name string
|
||||
SizeBytes int64
|
||||
Location string
|
||||
LastAccess string
|
||||
}{}, nil
|
||||
// ListDatasets retrieves available datasets from the filesystem
|
||||
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
|
||||
var datasets []model.DatasetInfo
|
||||
|
||||
// Scan the active data directory for datasets
|
||||
dataDir := tq.config.BasePath
|
||||
if dataDir == "" {
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dataDir)
|
||||
if err != nil {
|
||||
// Directory might not exist yet, return empty
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
datasets = append(datasets, model.DatasetInfo{
|
||||
Name: entry.Name(),
|
||||
SizeBytes: info.Size(),
|
||||
Location: filepath.Join(dataDir, entry.Name()),
|
||||
LastAccess: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
// ListExperiments retrieves experiment list
|
||||
|
|
|
|||
275
cmd/tui/internal/services/websocket.go
Normal file
275
cmd/tui/internal/services/websocket.go
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
// Package services provides TUI service clients
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// WebSocketClient manages real-time updates from the server over a single
// WebSocket connection, fanning them out on typed channels.
type WebSocketClient struct {
	conn      *websocket.Conn
	serverURL string
	apiKey    string
	logger    *logging.Logger

	// Channels for different update types (buffered, cap 100 each).
	jobUpdates    chan model.JobUpdateMsg
	gpuUpdates    chan model.GPUUpdateMsg
	statusUpdates chan model.StatusMsg

	// Control
	ctx    context.Context
	cancel context.CancelFunc
	// NOTE(review): connected is read and written from several goroutines
	// (reader, heartbeat, callers of IsConnected) without synchronization;
	// consider atomic.Bool or a mutex.
	connected bool
}
|
||||
|
||||
// JobUpdateMsg represents a real-time job status update
|
||||
type JobUpdateMsg struct {
|
||||
JobName string `json:"job_name"`
|
||||
Status string `json:"status"`
|
||||
TaskID string `json:"task_id"`
|
||||
Progress int `json:"progress"`
|
||||
}
|
||||
|
||||
// GPUUpdateMsg represents a real-time GPU status update
|
||||
type GPUUpdateMsg struct {
|
||||
DeviceID int `json:"device_id"`
|
||||
Utilization int `json:"utilization"`
|
||||
MemoryUsed int64 `json:"memory_used"`
|
||||
MemoryTotal int64 `json:"memory_total"`
|
||||
Temperature int `json:"temperature"`
|
||||
}
|
||||
|
||||
// NewWebSocketClient creates a new WebSocket client
|
||||
func NewWebSocketClient(serverURL, apiKey string, logger *logging.Logger) *WebSocketClient {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &WebSocketClient{
|
||||
serverURL: serverURL,
|
||||
apiKey: apiKey,
|
||||
logger: logger,
|
||||
jobUpdates: make(chan model.JobUpdateMsg, 100),
|
||||
gpuUpdates: make(chan model.GPUUpdateMsg, 100),
|
||||
statusUpdates: make(chan model.StatusMsg, 100),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (c *WebSocketClient) Connect() error {
|
||||
// Parse server URL and construct WebSocket URL
|
||||
u, err := url.Parse(c.serverURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid server URL: %w", err)
|
||||
}
|
||||
|
||||
// Convert http/https to ws/wss
|
||||
wsScheme := "ws"
|
||||
if u.Scheme == "https" {
|
||||
wsScheme = "wss"
|
||||
}
|
||||
wsURL := fmt.Sprintf("%s://%s/ws", wsScheme, u.Host)
|
||||
|
||||
// Create dialer with timeout
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
Subprotocols: []string{"fetchml-v1"},
|
||||
}
|
||||
|
||||
// Add API key to headers
|
||||
headers := http.Header{}
|
||||
if c.apiKey != "" {
|
||||
headers.Set("X-API-Key", c.apiKey)
|
||||
}
|
||||
|
||||
conn, resp, err := dialer.Dial(wsURL, headers)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return fmt.Errorf("websocket dial failed (status %d): %w", resp.StatusCode, err)
|
||||
}
|
||||
return fmt.Errorf("websocket dial failed: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.connected = true
|
||||
c.logger.Info("websocket connected", "url", wsURL)
|
||||
|
||||
// Start message handler
|
||||
go c.messageHandler()
|
||||
|
||||
// Start heartbeat
|
||||
go c.heartbeat()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect closes the WebSocket connection
|
||||
func (c *WebSocketClient) Disconnect() {
|
||||
c.cancel()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
c.connected = false
|
||||
}
|
||||
|
||||
// IsConnected returns true if connected. NOTE(review): the flag is
// updated by other goroutines without synchronization, so the answer is
// best-effort.
func (c *WebSocketClient) IsConnected() bool {
	return c.connected
}
|
||||
|
||||
// messageHandler reads messages from the WebSocket
|
||||
func (c *WebSocketClient) messageHandler() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if c.conn == nil {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set read deadline
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
// Read message
|
||||
messageType, data, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
c.logger.Error("websocket read error", "error", err)
|
||||
}
|
||||
c.connected = false
|
||||
|
||||
// Attempt reconnect
|
||||
time.Sleep(5 * time.Second)
|
||||
if err := c.Connect(); err != nil {
|
||||
c.logger.Error("websocket reconnect failed", "error", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle binary vs text messages
|
||||
if messageType == websocket.BinaryMessage {
|
||||
c.handleBinaryMessage(data)
|
||||
} else {
|
||||
c.handleTextMessage(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleBinaryMessage handles binary WebSocket messages
|
||||
func (c *WebSocketClient) handleBinaryMessage(data []byte) {
|
||||
if len(data) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
// Binary protocol: [opcode:1][data...]
|
||||
opcode := data[0]
|
||||
payload := data[1:]
|
||||
|
||||
switch opcode {
|
||||
case 0x01: // Job update
|
||||
var update JobUpdateMsg
|
||||
if err := json.Unmarshal(payload, &update); err != nil {
|
||||
c.logger.Error("failed to unmarshal job update", "error", err)
|
||||
return
|
||||
}
|
||||
c.jobUpdates <- model.JobUpdateMsg(update)
|
||||
|
||||
case 0x02: // GPU update
|
||||
var update GPUUpdateMsg
|
||||
if err := json.Unmarshal(payload, &update); err != nil {
|
||||
c.logger.Error("failed to unmarshal GPU update", "error", err)
|
||||
return
|
||||
}
|
||||
c.gpuUpdates <- model.GPUUpdateMsg(update)
|
||||
|
||||
case 0x03: // Status message
|
||||
var status model.StatusMsg
|
||||
if err := json.Unmarshal(payload, &status); err != nil {
|
||||
c.logger.Error("failed to unmarshal status", "error", err)
|
||||
return
|
||||
}
|
||||
c.statusUpdates <- status
|
||||
}
|
||||
}
|
||||
|
||||
// handleTextMessage handles text WebSocket messages (JSON)
|
||||
func (c *WebSocketClient) handleTextMessage(data []byte) {
|
||||
var msg map[string]interface{}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
c.logger.Error("failed to unmarshal text message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
msgType, _ := msg["type"].(string)
|
||||
switch msgType {
|
||||
case "job_update":
|
||||
// Handle JSON job updates
|
||||
case "gpu_update":
|
||||
// Handle JSON GPU updates
|
||||
case "status":
|
||||
// Handle status messages
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat sends periodic ping messages
|
||||
func (c *WebSocketClient) heartbeat() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if c.conn != nil {
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
c.logger.Error("websocket ping failed", "error", err)
|
||||
c.connected = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe subscribes to specific update channels
|
||||
func (c *WebSocketClient) Subscribe(channels ...string) error {
|
||||
if !c.connected {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
subMsg := map[string]interface{}{
|
||||
"action": "subscribe",
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(subMsg)
|
||||
return c.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// GetJobUpdates returns the job updates channel
|
||||
func (c *WebSocketClient) GetJobUpdates() <-chan model.JobUpdateMsg {
|
||||
return c.jobUpdates
|
||||
}
|
||||
|
||||
// GetGPUUpdates returns the GPU updates channel
|
||||
func (c *WebSocketClient) GetGPUUpdates() <-chan model.GPUUpdateMsg {
|
||||
return c.gpuUpdates
|
||||
}
|
||||
|
||||
// GetStatusUpdates returns the status updates channel
|
||||
func (c *WebSocketClient) GetStatusUpdates() <-chan model.StatusMsg {
|
||||
return c.statusUpdates
|
||||
}
|
||||
147
cmd/tui/internal/view/narrative_view.go
Normal file
147
cmd/tui/internal/view/narrative_view.go
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
// Package view provides TUI rendering functionality
|
||||
package view
|
||||
|
||||
import (
	"strings"
	"unicode/utf8"

	"github.com/charmbracelet/lipgloss"
	"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
)
|
||||
|
||||
// NarrativeView displays a job's research context (hypothesis, context,
// intent, expected vs. actual outcome) in a bordered panel.
type NarrativeView struct {
	Width  int // available render width
	Height int // available render height
}
|
||||
|
||||
// View renders the narrative/outcome for a job
|
||||
func (v *NarrativeView) View(job model.Job) string {
|
||||
if job.Hypothesis == "" && job.Context == "" && job.Intent == "" {
|
||||
return v.renderEmpty()
|
||||
}
|
||||
|
||||
var sections []string
|
||||
|
||||
// Title
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#ff9e64")).
|
||||
MarginBottom(1)
|
||||
sections = append(sections, titleStyle.Render("📊 Research Context"))
|
||||
|
||||
// Hypothesis
|
||||
if job.Hypothesis != "" {
|
||||
sections = append(sections, v.renderSection("🧪 Hypothesis", job.Hypothesis))
|
||||
}
|
||||
|
||||
// Context
|
||||
if job.Context != "" {
|
||||
sections = append(sections, v.renderSection("📚 Context", job.Context))
|
||||
}
|
||||
|
||||
// Intent
|
||||
if job.Intent != "" {
|
||||
sections = append(sections, v.renderSection("🎯 Intent", job.Intent))
|
||||
}
|
||||
|
||||
// Expected Outcome
|
||||
if job.ExpectedOutcome != "" {
|
||||
sections = append(sections, v.renderSection("📈 Expected Outcome", job.ExpectedOutcome))
|
||||
}
|
||||
|
||||
// Actual Outcome (if available)
|
||||
if job.ActualOutcome != "" {
|
||||
statusIcon := "❓"
|
||||
switch job.OutcomeStatus {
|
||||
case "validated":
|
||||
statusIcon = "✅"
|
||||
case "invalidated":
|
||||
statusIcon = "❌"
|
||||
case "inconclusive":
|
||||
statusIcon = "⚖️"
|
||||
case "partial":
|
||||
statusIcon = "📝"
|
||||
}
|
||||
sections = append(sections, v.renderSection(statusIcon+" Actual Outcome", job.ActualOutcome))
|
||||
}
|
||||
|
||||
// Combine all sections
|
||||
content := lipgloss.JoinVertical(lipgloss.Left, sections...)
|
||||
|
||||
// Apply border and padding
|
||||
style := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7aa2f7")).
|
||||
Padding(1, 2).
|
||||
Width(v.Width - 4)
|
||||
|
||||
return style.Render(content)
|
||||
}
|
||||
|
||||
func (v *NarrativeView) renderSection(title, content string) string {
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7aa2f7"))
|
||||
|
||||
contentStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#d8dee9")).
|
||||
MarginLeft(2)
|
||||
|
||||
// Wrap long content
|
||||
wrapped := wrapText(content, v.Width-12)
|
||||
|
||||
return lipgloss.JoinVertical(lipgloss.Left,
|
||||
titleStyle.Render(title),
|
||||
contentStyle.Render(wrapped),
|
||||
"",
|
||||
)
|
||||
}
|
||||
|
||||
func (v *NarrativeView) renderEmpty() string {
|
||||
style := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#88c0d0")).
|
||||
Italic(true).
|
||||
Padding(2)
|
||||
|
||||
return style.Render("No narrative data available for this job.\n\n" +
|
||||
"Use --hypothesis, --context, --intent flags when queuing jobs.")
|
||||
}
|
||||
|
||||
// wrapText word-wraps text to maxWidth columns while preserving existing
// newlines. Words longer than maxWidth are emitted on their own line
// rather than broken mid-word. A non-positive maxWidth disables wrapping.
//
// Fix: widths are measured in runes rather than bytes, so multi-byte
// characters (emoji, accented letters) no longer trigger premature wraps.
func wrapText(text string, maxWidth int) string {
	if maxWidth <= 0 {
		return text
	}

	var result []string
	for _, line := range strings.Split(text, "\n") {
		if utf8.RuneCountInString(line) <= maxWidth {
			result = append(result, line)
			continue
		}

		// Simple greedy word wrap. Note strings.Fields collapses runs of
		// whitespace inside over-long lines (as the original did).
		var current string
		currentLen := 0
		for _, word := range strings.Fields(line) {
			wordLen := utf8.RuneCountInString(word)
			if currentLen+wordLen+1 > maxWidth {
				if current != "" {
					result = append(result, current)
				}
				current = word
				currentLen = wordLen
			} else {
				if current != "" {
					current += " "
					currentLen++
				}
				current += word
				currentLen += wordLen
			}
		}
		if current != "" {
			result = append(result, current)
		}
	}

	return strings.Join(result, "\n")
}
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
package view
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
|
@ -180,6 +181,27 @@ func getRightPanel(m model.State, width int) string {
|
|||
style = activeBorderStyle
|
||||
viewTitle = "📦 Datasets"
|
||||
content = m.DatasetView.View()
|
||||
case model.ViewModeNarrative:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "📊 Research Context"
|
||||
narrativeView := &NarrativeView{Width: width, Height: m.Height}
|
||||
content = narrativeView.View(m.SelectedJob)
|
||||
case model.ViewModeTeam:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "👥 Team Jobs"
|
||||
content = "Team collaboration view - shows jobs from all team members\n\n(Requires API: GET /api/jobs?all_users=true)"
|
||||
case model.ViewModeExperimentHistory:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "📜 Experiment History"
|
||||
content = m.ExperimentHistoryView.View()
|
||||
case model.ViewModeConfig:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "⚙️ Config"
|
||||
content = m.ConfigView.View()
|
||||
case model.ViewModeLogs:
|
||||
style = activeBorderStyle
|
||||
viewTitle = "📜 Logs"
|
||||
content = m.LogsView.View()
|
||||
default:
|
||||
viewTitle = "📊 System Overview"
|
||||
content = getOverviewPanel(m)
|
||||
|
|
@ -218,7 +240,14 @@ func getStatusBar(m model.State) string {
|
|||
if m.ShowHelp {
|
||||
statusText = "Press 'h' to hide help"
|
||||
}
|
||||
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText)
|
||||
|
||||
// Add refresh rate indicator
|
||||
refreshInfo := ""
|
||||
if m.RefreshRate > 0 {
|
||||
refreshInfo = fmt.Sprintf(" | %.0fms", m.RefreshRate)
|
||||
}
|
||||
|
||||
return statusStyle.Width(m.Width - 4).Render(spinnerStr + " " + statusText + refreshInfo)
|
||||
}
|
||||
|
||||
func helpText(m model.State) string {
|
||||
|
|
@ -242,25 +271,28 @@ func helpText(m model.State) string {
|
|||
║ Navigation ║
|
||||
║ j/k, ↑/↓ : Move selection / : Filter jobs ║
|
||||
║ 1 : Job list view 2 : Datasets view ║
|
||||
║ 3 : Experiments view v : Queue view ║
|
||||
║ 3 : Experiments view q : Queue view ║
|
||||
║ g : GPU view o : Container view ║
|
||||
║ s : Settings view ║
|
||||
║ s : Settings view n : Narrative view ║
|
||||
║ m : Team view e : Experiment history ║
|
||||
║ c : Config view l : Logs (job stream) ║
|
||||
║ E : Export job @ : Filter by team ║
|
||||
║ ║
|
||||
║ Actions ║
|
||||
║ t : Queue job a : Queue w/ args ║
|
||||
║ c : Cancel task d : Delete pending ║
|
||||
║ x : Cancel task d : Delete pending ║
|
||||
║ f : Mark as failed r : Refresh all ║
|
||||
║ G : Refresh GPU only ║
|
||||
║ ║
|
||||
║ General ║
|
||||
║ h or ? : Toggle this help q/Ctrl+C : Quit ║
|
||||
║ h or ? : Toggle this help Ctrl+C : Quit ║
|
||||
╚═══════════════════════════════════════════════════════════════╝`
|
||||
}
|
||||
|
||||
func getQuickHelp(m model.State) string {
|
||||
if m.ActiveView == model.ViewModeSettings {
|
||||
return " ↑/↓:move enter:select esc:exit settings q:quit"
|
||||
return " ↑/↓:move enter:select esc:exit settings Ctrl+C:quit"
|
||||
}
|
||||
return " h:help 1:jobs 2:datasets 3:experiments v:queue g:gpu o:containers " +
|
||||
"s:settings t:queue r:refresh q:quit"
|
||||
return " h:help 1:jobs 2:datasets 3:experiments q:queue g:gpu o:containers " +
|
||||
"n:narrative m:team e:history c:config l:logs E:export @:filter s:settings t:trigger r:refresh Ctrl+C:quit"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
|
@ -41,6 +42,16 @@ func (m AppModel) View() string {
|
|||
}
|
||||
|
||||
func main() {
|
||||
// Redirect logs to file to prevent TUI disruption
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
logDir := filepath.Join(homeDir, ".ml", "logs")
|
||||
os.MkdirAll(logDir, 0755)
|
||||
logFile, logErr := os.OpenFile(filepath.Join(logDir, "tui.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if logErr == nil {
|
||||
log.SetOutput(logFile)
|
||||
defer logFile.Close()
|
||||
}
|
||||
|
||||
// Parse authentication flags
|
||||
authFlags := auth.ParseAuthFlags()
|
||||
if err := auth.ValidateFlags(authFlags); err != nil {
|
||||
|
|
@ -85,10 +96,10 @@ func main() {
|
|||
log.Printf(" 4. Run TUI: ./bin/tui")
|
||||
log.Printf("")
|
||||
log.Printf("Example ~/.ml/config.toml:")
|
||||
log.Printf(" worker_host = \"localhost\"")
|
||||
log.Printf(" worker_user = \"your_username\"")
|
||||
log.Printf(" worker_base = \"~/ml_jobs\"")
|
||||
log.Printf(" worker_port = 22")
|
||||
log.Printf(" mode = \"dev\"")
|
||||
log.Printf(" # Paths auto-resolve based on mode:")
|
||||
log.Printf(" # dev mode: ./data/dev/experiments")
|
||||
log.Printf(" # prod mode: ./data/prod/experiments")
|
||||
log.Printf(" api_key = \"your_api_key_here\"")
|
||||
log.Printf("")
|
||||
log.Printf("For more help, see: https://github.com/jfraeys/fetch_ml/docs")
|
||||
|
|
@ -101,6 +112,11 @@ func main() {
|
|||
}
|
||||
log.Printf("Loaded TOML configuration from %s", cliConfPath)
|
||||
|
||||
// Force local mode - TUI runs on server with direct filesystem access
|
||||
cfg.Host = ""
|
||||
// Clear BasePath to force mode-based path resolution
|
||||
cfg.BasePath = ""
|
||||
|
||||
// Validate authentication configuration
|
||||
if err := cfg.Auth.ValidateAuthConfig(); err != nil {
|
||||
log.Fatalf("Invalid authentication configuration: %v", err)
|
||||
|
|
@ -130,7 +146,7 @@ func main() {
|
|||
|
||||
srv, err := services.NewMLServer(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to server: %v", err)
|
||||
log.Fatalf("Failed to initialize local server: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := srv.Close(); err != nil {
|
||||
|
|
@ -138,19 +154,24 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
// TaskQueue is optional for local mode
|
||||
tq, err := services.NewTaskQueue(cfg)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to Redis: %v", err)
|
||||
return
|
||||
log.Printf("Warning: Failed to connect to Redis: %v", err)
|
||||
log.Printf("Continuing without task queue functionality")
|
||||
tq = nil
|
||||
}
|
||||
if tq != nil {
|
||||
defer func() {
|
||||
if err := tq.Close(); err != nil {
|
||||
log.Printf("task queue close error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Initialize logger with error level and no debug output
|
||||
logger := logging.NewLogger(-4, false) // -4 = slog.LevelError
|
||||
// Initialize logger with file output only (prevents TUI disruption)
|
||||
logFilePath := filepath.Join(logDir, "tui.log")
|
||||
logger := logging.NewFileLogger(-4, false, logFilePath) // -4 = slog.LevelError
|
||||
|
||||
// Initialize State and Controller
|
||||
var effectiveAPIKey string
|
||||
|
|
@ -177,14 +198,17 @@ func main() {
|
|||
|
||||
go func() {
|
||||
<-sigChan
|
||||
logger.Info("Received shutdown signal, closing TUI...")
|
||||
p.Quit()
|
||||
}()
|
||||
|
||||
if _, err := p.Run(); err != nil {
|
||||
_ = p.ReleaseTerminal()
|
||||
log.Printf("Error running TUI: %v", err)
|
||||
logger.Error("Error running TUI", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure terminal is released and resources are closed via defer statements
|
||||
_ = p.ReleaseTerminal()
|
||||
logger.Info("TUI shutdown complete")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ security:
|
|||
api_key_rotation_days: 90
|
||||
audit_logging:
|
||||
enabled: true
|
||||
log_path: "/tmp/fetchml-audit.log"
|
||||
log_path: "data/dev/logs/fetchml-audit.log"
|
||||
rate_limit:
|
||||
enabled: false
|
||||
requests_per_minute: 60
|
||||
|
|
@ -46,8 +46,8 @@ database:
|
|||
|
||||
logging:
|
||||
level: "info"
|
||||
file: ""
|
||||
audit_log: ""
|
||||
file: "data/dev/logs/fetchml.log"
|
||||
audit_log: "data/dev/logs/fetchml-audit.log"
|
||||
|
||||
resources:
|
||||
max_workers: 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
base_path: "./data/experiments"
|
||||
base_path: "./data/dev/experiments"
|
||||
|
||||
data_dir: "./data/active"
|
||||
data_dir: "./data/dev/active"
|
||||
|
||||
auth:
|
||||
enabled: false
|
||||
|
|
@ -19,7 +19,7 @@ security:
|
|||
api_key_rotation_days: 90
|
||||
audit_logging:
|
||||
enabled: true
|
||||
log_path: "./data/fetchml-audit.log"
|
||||
log_path: "./data/dev/logs/fetchml-audit.log"
|
||||
rate_limit:
|
||||
enabled: false
|
||||
requests_per_minute: 60
|
||||
|
|
@ -42,12 +42,12 @@ redis:
|
|||
|
||||
database:
|
||||
type: "sqlite"
|
||||
connection: "./data/fetchml.sqlite"
|
||||
connection: "./data/dev/fetchml.sqlite"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
file: ""
|
||||
audit_log: ""
|
||||
file: "./data/dev/logs/fetchml.log"
|
||||
audit_log: "./data/dev/logs/fetchml-audit.log"
|
||||
|
||||
resources:
|
||||
max_workers: 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
base_path: "/app/data/experiments"
|
||||
base_path: "/app/data/prod/experiments"
|
||||
|
||||
data_dir: "/data/active"
|
||||
data_dir: "/app/data/prod/active"
|
||||
|
||||
auth:
|
||||
enabled: true
|
||||
|
|
@ -45,12 +45,12 @@ redis:
|
|||
|
||||
database:
|
||||
type: "sqlite"
|
||||
connection: "/app/data/experiments/fetch_ml.sqlite"
|
||||
connection: "/app/data/prod/fetch_ml.sqlite"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
file: "/logs/fetch_ml.log"
|
||||
audit_log: ""
|
||||
file: "/app/data/prod/logs/fetch_ml.log"
|
||||
audit_log: "/app/data/prod/logs/audit.log"
|
||||
|
||||
resources:
|
||||
max_workers: 2
|
||||
|
|
|
|||
|
|
@ -36,6 +36,22 @@ func (s *Server) registerRoutes(mux *http.ServeMux) {
|
|||
|
||||
// Register HTTP API handlers
|
||||
s.handlers.RegisterHandlers(mux)
|
||||
|
||||
// Register new REST API endpoints for TUI
|
||||
jobsHandler := jobs.NewHandler(
|
||||
s.expManager,
|
||||
s.logger,
|
||||
s.taskQueue,
|
||||
s.db,
|
||||
s.config.BuildAuthConfig(),
|
||||
nil,
|
||||
)
|
||||
|
||||
// Experiment history endpoint: GET /api/experiments/:id/history
|
||||
mux.HandleFunc("GET /api/experiments/{id}/history", jobsHandler.GetExperimentHistoryHTTP)
|
||||
|
||||
// Team jobs endpoint: GET /api/jobs?all_users=true
|
||||
mux.HandleFunc("GET /api/jobs", jobsHandler.ListAllJobsHTTP)
|
||||
}
|
||||
|
||||
// registerHealthRoutes sets up health check endpoints
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jfraeys/fetch_ml/internal/audit"
|
||||
|
|
@ -110,6 +112,22 @@ const (
|
|||
PermJupyterRead = "jupyter:read"
|
||||
)
|
||||
|
||||
// ClientType distinguishes the kind of WebSocket client that connected.
type ClientType int

const (
	// ClientTypeCLI is a command-line client.
	ClientTypeCLI ClientType = iota
	// ClientTypeTUI is a terminal-UI client that receives push updates.
	ClientTypeTUI
)
|
||||
|
||||
// Client represents a connected WebSocket client and the metadata used
// when broadcasting push updates.
type Client struct {
	conn       *websocket.Conn
	Type       ClientType // CLI vs TUI; broadcasts target TUI clients
	User       string     // user label associated with the connection
	RemoteAddr string     // remote network address, used in log messages
}
|
||||
|
||||
// Handler provides WebSocket handling
|
||||
type Handler struct {
|
||||
authConfig *auth.Config
|
||||
|
|
@ -125,6 +143,10 @@ type Handler struct {
|
|||
jobsHandler *jobs.Handler
|
||||
jupyterHandler *jupyterj.Handler
|
||||
datasetsHandler *datasets.Handler
|
||||
|
||||
// Client management for push updates
|
||||
clients map[*Client]bool
|
||||
clientsMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHandler creates a new WebSocket handler
|
||||
|
|
@ -158,6 +180,7 @@ func NewHandler(
|
|||
jobsHandler: jobsHandler,
|
||||
jupyterHandler: jupyterHandler,
|
||||
datasetsHandler: datasetsHandler,
|
||||
clients: make(map[*Client]bool),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -217,6 +240,25 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
func (h *Handler) handleConnection(conn *websocket.Conn) {
|
||||
h.logger.Info("websocket connection established", "remote", conn.RemoteAddr())
|
||||
|
||||
// Register client
|
||||
client := &Client{
|
||||
conn: conn,
|
||||
Type: ClientTypeTUI, // Assume TUI for now, could detect from handshake
|
||||
User: "tui-user",
|
||||
RemoteAddr: conn.RemoteAddr().String(),
|
||||
}
|
||||
|
||||
h.clientsMu.Lock()
|
||||
h.clients[client] = true
|
||||
h.clientsMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
h.clientsMu.Lock()
|
||||
delete(h.clients, client)
|
||||
h.clientsMu.Unlock()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
messageType, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
|
|
@ -765,3 +807,66 @@ func (h *Handler) handleSetRunOutcome(conn *websocket.Conn, payload []byte) erro
|
|||
"message": "Outcome updated",
|
||||
})
|
||||
}
|
||||
|
||||
// BroadcastJobUpdate sends job status update to all connected TUI clients
|
||||
func (h *Handler) BroadcastJobUpdate(jobName, status string, progress int) {
|
||||
h.clientsMu.RLock()
|
||||
defer h.clientsMu.RUnlock()
|
||||
|
||||
msg := map[string]any{
|
||||
"type": "job_update",
|
||||
"job_name": jobName,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"time": time.Now().Unix(),
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(msg)
|
||||
|
||||
for client := range h.clients {
|
||||
if client.Type == ClientTypeTUI {
|
||||
if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
||||
h.logger.Warn("failed to broadcast to client", "error", err, "client", client.RemoteAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGPUUpdate sends GPU status update to all connected TUI clients
|
||||
func (h *Handler) BroadcastGPUUpdate(deviceID, utilization int, memoryUsed, memoryTotal int64) {
|
||||
h.clientsMu.RLock()
|
||||
defer h.clientsMu.RUnlock()
|
||||
|
||||
msg := map[string]any{
|
||||
"type": "gpu_update",
|
||||
"device_id": deviceID,
|
||||
"utilization": utilization,
|
||||
"memory_used": memoryUsed,
|
||||
"memory_total": memoryTotal,
|
||||
"time": time.Now().Unix(),
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(msg)
|
||||
|
||||
for client := range h.clients {
|
||||
if client.Type == ClientTypeTUI {
|
||||
if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
||||
h.logger.Warn("failed to broadcast GPU update", "error", err, "client", client.RemoteAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectedClientCount returns the number of connected TUI clients
|
||||
func (h *Handler) GetConnectedClientCount() int {
|
||||
h.clientsMu.RLock()
|
||||
defer h.clientsMu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for client := range h.clients {
|
||||
if client.Type == ClientTypeTUI {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
|
|
|||
|
|
@ -229,6 +229,37 @@ func (s *SmartDefaults) ExpandPath(path string) string {
|
|||
return path
|
||||
}
|
||||
|
||||
// ModeBasedPath returns the appropriate path for a given mode
|
||||
func ModeBasedPath(mode string, subpath string) string {
|
||||
switch mode {
|
||||
case "dev":
|
||||
return filepath.Join("./data/dev", subpath)
|
||||
case "prod":
|
||||
return filepath.Join("./data/prod", subpath)
|
||||
case "ci":
|
||||
return filepath.Join("./data/ci", subpath)
|
||||
case "prod-smoke":
|
||||
return filepath.Join("./data/prod-smoke", subpath)
|
||||
default:
|
||||
return filepath.Join("./data/dev", subpath)
|
||||
}
|
||||
}
|
||||
|
||||
// ModeBasedBasePath returns the experiments base path for a given mode
// (e.g. ./data/dev/experiments).
func ModeBasedBasePath(mode string) string {
	return ModeBasedPath(mode, "experiments")
}

// ModeBasedDataDir returns the active data directory for a given mode
// (e.g. ./data/dev/active).
func ModeBasedDataDir(mode string) string {
	return ModeBasedPath(mode, "active")
}

// ModeBasedLogDir returns the logs directory for a given mode
// (e.g. ./data/dev/logs).
func ModeBasedLogDir(mode string) string {
	return ModeBasedPath(mode, "logs")
}
|
||||
|
||||
// GetEnvironmentDescription returns a human-readable description
|
||||
func (s *SmartDefaults) GetEnvironmentDescription() string {
|
||||
switch s.Profile {
|
||||
|
|
|
|||
Loading…
Reference in a new issue