Compare commits
72 commits
a64233d4f6
...
ed7b5032a9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed7b5032a9 | ||
|
|
be39b37aec | ||
|
|
1a1844e9e9 | ||
|
|
b1c9bc97fc | ||
|
|
382c67edfc | ||
|
|
ccd1dd7a4d | ||
|
|
20fde4f79d | ||
|
|
c56e53cb52 | ||
|
|
05b7af6991 | ||
|
|
d6265df0bd | ||
|
|
e557313e08 | ||
|
|
5f8e7c59a5 | ||
|
|
fa383ebc6f | ||
|
|
158c525bef | ||
|
|
90d702823b | ||
|
|
d1ac558107 | ||
|
|
48d00b8322 | ||
|
|
25ae791b5c | ||
|
|
1a35c54300 | ||
|
|
4b2ee75072 | ||
|
|
c89d970210 | ||
|
|
472590f831 | ||
|
|
7efe8bbfbf | ||
|
|
201cb66f56 | ||
|
|
a3b957dcc0 | ||
|
|
04ac745b01 | ||
|
|
7c4a59012b | ||
|
|
adf4c2a834 | ||
|
|
d3461cd07f | ||
|
|
f5b68cca49 | ||
|
|
d0c68772ea | ||
|
|
551597b5df | ||
|
|
d43725b817 | ||
|
|
96c4c376d8 | ||
|
|
23e5f3d1dc | ||
|
|
7583932897 | ||
|
|
2258f60ade | ||
|
|
7ce0fd251e | ||
|
|
2c596038b5 | ||
|
|
ff542b533f | ||
|
|
6028779239 | ||
|
|
02811c0ffe | ||
|
|
37aad7ae87 | ||
|
|
a3f9bf8731 | ||
|
|
e4d286f2e5 | ||
|
|
34aaba8f17 | ||
|
|
f357624685 | ||
|
|
27c8b08a16 | ||
|
|
4756348c48 | ||
|
|
cb826b74a3 | ||
|
|
b2eba75f09 | ||
|
|
aaeef69bab | ||
|
|
260e18499e | ||
|
|
94020e4ca4 | ||
|
|
8b75f71a6a | ||
|
|
5e8dc08643 | ||
|
|
b4672a6c25 | ||
|
|
38b6c3323a | ||
|
|
d9ed8f4ffa | ||
|
|
f7afb36a7c | ||
|
|
33b893a71a | ||
|
|
a5059c5231 | ||
|
|
4bee42493b | ||
|
|
2101e4a01c | ||
|
|
3e744bf312 | ||
|
|
e127f97442 | ||
|
|
64e306bd72 | ||
|
|
7880ea8d79 | ||
|
|
5644338ebd | ||
|
|
c9b6532dfb | ||
|
|
412d7b82e9 | ||
|
|
6446379a40 |
285 changed files with 25097 additions and 3642 deletions
80
.forgejo/workflows/build-cli.yml
Normal file
80
.forgejo/workflows/build-cli.yml
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
name: Build CLI with Embedded SQLite
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, master]
|
||||
paths:
|
||||
- 'cli/**'
|
||||
- '.forgejo/workflows/build-cli.yml'
|
||||
pull_request:
|
||||
branches: [main, master]
|
||||
paths:
|
||||
- 'cli/**'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
target:
|
||||
- x86_64-linux
|
||||
- aarch64-linux
|
||||
include:
|
||||
- target: x86_64-linux
|
||||
arch: x86_64
|
||||
- target: aarch64-linux
|
||||
arch: arm64
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Zig
|
||||
uses: goto-bus-stop/setup-zig@v2
|
||||
with:
|
||||
version: 0.15.0
|
||||
|
||||
- name: Fetch SQLite Amalgamation
|
||||
run: |
|
||||
cd cli
|
||||
make build-sqlite SQLITE_VERSION=3450000
|
||||
|
||||
- name: Build Release Binary
|
||||
run: |
|
||||
cd cli
|
||||
zig build prod -Dtarget=${{ matrix.target }}
|
||||
|
||||
- name: Upload Artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ml-cli-${{ matrix.target }}
|
||||
path: cli/zig-out/bin/ml
|
||||
|
||||
build-macos:
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86_64, arm64]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Zig
|
||||
uses: goto-bus-stop/setup-zig@v2
|
||||
with:
|
||||
version: 0.15.0
|
||||
|
||||
- name: Fetch SQLite Amalgamation
|
||||
run: |
|
||||
cd cli
|
||||
make build-sqlite SQLITE_VERSION=3450000
|
||||
|
||||
- name: Build Release Binary
|
||||
run: |
|
||||
cd cli
|
||||
zig build prod -Dtarget=${{ matrix.arch }}-macos
|
||||
|
||||
- name: Upload Artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ml-cli-${{ matrix.arch }}-macos
|
||||
path: cli/zig-out/bin/ml
|
||||
|
|
@ -197,33 +197,29 @@ jobs:
|
|||
- name: Test with Native Libraries
|
||||
run: |
|
||||
echo "Running tests WITH native libraries enabled..."
|
||||
FETCHML_NATIVE_LIBS=1 go test -v ./tests/...
|
||||
env:
|
||||
FETCHML_NATIVE_LIBS: "1"
|
||||
CGO_ENABLED=1 go test -tags native_libs -v ./tests/...
|
||||
continue-on-error: true
|
||||
|
||||
- name: Native Smoke Test
|
||||
run: |
|
||||
echo "Running native libraries smoke test..."
|
||||
make native-smoke
|
||||
env:
|
||||
FETCHML_NATIVE_LIBS: "1"
|
||||
CGO_ENABLED=1 go test -tags native_libs ./tests/benchmarks/... -run TestNative
|
||||
continue-on-error: true
|
||||
|
||||
- name: Test Fallback (Go only)
|
||||
run: |
|
||||
echo "Running tests WITHOUT native libraries (Go fallback)..."
|
||||
FETCHML_NATIVE_LIBS=0 go test -v ./tests/...
|
||||
env:
|
||||
FETCHML_NATIVE_LIBS: "0"
|
||||
go test -v ./tests/...
|
||||
continue-on-error: true
|
||||
|
||||
- name: Run Benchmarks
|
||||
run: |
|
||||
echo "Running performance benchmarks..."
|
||||
echo "=== Go Implementation ==="
|
||||
FETCHML_NATIVE_LIBS=0 go test -bench=. ./tests/benchmarks/ -benchmem || true
|
||||
go test -bench=. ./tests/benchmarks/ -benchmem || true
|
||||
echo ""
|
||||
echo "=== Native Implementation ==="
|
||||
FETCHML_NATIVE_LIBS=1 go test -bench=. ./tests/benchmarks/ -benchmem || true
|
||||
CGO_ENABLED=1 go test -tags native_libs -bench=. ./tests/benchmarks/ -benchmem || true
|
||||
|
||||
- name: Lint
|
||||
run: |
|
||||
|
|
|
|||
120
.forgejo/workflows/contract-test.yml
Normal file
120
.forgejo/workflows/contract-test.yml
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
name: Contract Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
paths:
|
||||
- 'api/openapi.yaml'
|
||||
- 'internal/api/server_gen.go'
|
||||
- 'internal/api/adapter.go'
|
||||
- '.forgejo/workflows/contract-test.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'api/openapi.yaml'
|
||||
- 'internal/api/server_gen.go'
|
||||
- 'internal/api/adapter.go'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
GO_VERSION: '1.25.0'
|
||||
|
||||
jobs:
|
||||
spec-drift-check:
|
||||
name: Spec Drift Detection
|
||||
runs-on: self-hosted
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Go
|
||||
run: |
|
||||
go version || (wget https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz && \
|
||||
sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go${GO_VERSION}.linux-amd64.tar.gz)
|
||||
echo "PATH=$PATH:/usr/local/go/bin" >> $GITHUB_ENV
|
||||
go version
|
||||
|
||||
- name: Install oapi-codegen
|
||||
run: |
|
||||
go install github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen@latest
|
||||
echo "PATH=$PATH:$(go env GOPATH)/bin" >> $GITHUB_ENV
|
||||
|
||||
- name: Verify spec matches implementation
|
||||
run: make openapi-check-implementation
|
||||
|
||||
contract-test:
|
||||
name: API Contract Tests
|
||||
runs-on: self-hosted
|
||||
timeout-minutes: 15
|
||||
needs: spec-drift-check
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Go
|
||||
run: |
|
||||
go version || (wget https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz && \
|
||||
sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go${GO_VERSION}.linux-amd64.tar.gz)
|
||||
echo "PATH=$PATH:/usr/local/go/bin" >> $GITHUB_ENV
|
||||
go version
|
||||
|
||||
- name: Build API server
|
||||
run: |
|
||||
go build -o api-server ./cmd/api-server
|
||||
|
||||
- name: Start API server
|
||||
run: |
|
||||
./api-server --config configs/api/dev-local.yaml &
|
||||
echo "API_PID=$!" >> $GITHUB_ENV
|
||||
sleep 5
|
||||
|
||||
- name: Install schemathesis
|
||||
run: |
|
||||
pip install schemathesis
|
||||
|
||||
- name: Run contract tests
|
||||
run: |
|
||||
schemathesis run api/openapi.yaml \
|
||||
--base-url http://localhost:8080 \
|
||||
--checks all \
|
||||
--max-response-time 5000 \
|
||||
--hypothesis-max-examples 50
|
||||
continue-on-error: true # Allow failures until all endpoints are fully implemented
|
||||
|
||||
- name: Stop API server
|
||||
if: always()
|
||||
run: |
|
||||
if [ -n "$API_PID" ]; then
|
||||
kill $API_PID || true
|
||||
fi
|
||||
|
||||
- name: Basic endpoint verification
|
||||
run: |
|
||||
echo "Testing /health endpoint..."
|
||||
curl -s http://localhost:8080/health || echo "Server not running, skipping"
|
||||
|
||||
echo "Testing /v1/experiments endpoint..."
|
||||
curl -s http://localhost:8080/v1/experiments || echo "Server not running, skipping"
|
||||
|
||||
echo "Testing /v1/tasks endpoint..."
|
||||
curl -s http://localhost:8080/v1/tasks || echo "Server not running, skipping"
|
||||
continue-on-error: true
|
||||
63
.forgejo/workflows/docs-deploy.yml
Normal file
63
.forgejo/workflows/docs-deploy.yml
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
name: Deploy API Docs
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'api/openapi.yaml'
|
||||
- '.forgejo/workflows/docs-deploy.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-docs:
|
||||
name: Build API Documentation
|
||||
runs-on: self-hosted
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install Redocly CLI
|
||||
run: npm install -g @redocly/cli
|
||||
|
||||
- name: Generate documentation
|
||||
run: |
|
||||
mkdir -p docs/api
|
||||
redocly build-docs api/openapi.yaml \
|
||||
--output docs/api/index.html \
|
||||
--title "FetchML API"
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@v3
|
||||
with:
|
||||
path: docs/api
|
||||
|
||||
deploy-docs:
|
||||
name: Deploy to GitHub Pages
|
||||
needs: build-docs
|
||||
runs-on: self-hosted
|
||||
timeout-minutes: 10
|
||||
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v4
|
||||
|
|
@ -74,7 +74,7 @@ jobs:
|
|||
*) echo "Unsupported Zig target: $TARGET"; exit 1 ;;
|
||||
esac
|
||||
|
||||
RSYNC_OUT="cli/src/assets/rsync_release_${OS}_${ARCH}.bin"
|
||||
RSYNC_OUT="cli/src/assets/rsync/rsync_release_${OS}_${ARCH}.bin"
|
||||
|
||||
wget -O "$RSYNC_OUT" ${{ matrix.rsync-url }} || \
|
||||
curl -L -o "$RSYNC_OUT" ${{ matrix.rsync-url }}
|
||||
|
|
@ -83,6 +83,37 @@ jobs:
|
|||
chmod +x "$RSYNC_OUT"
|
||||
ls -lh "$RSYNC_OUT"
|
||||
|
||||
- name: Download SQLite amalgamation
|
||||
run: |
|
||||
TARGET="${{ matrix.target }}"
|
||||
OS=""
|
||||
ARCH=""
|
||||
case "$TARGET" in
|
||||
x86_64-linux-*) OS="linux"; ARCH="x86_64" ;;
|
||||
aarch64-linux-*) OS="linux"; ARCH="arm64" ;;
|
||||
x86_64-macos*) OS="darwin"; ARCH="x86_64" ;;
|
||||
aarch64-macos*) OS="darwin"; ARCH="arm64" ;;
|
||||
x86_64-windows*) OS="windows"; ARCH="x86_64" ;;
|
||||
aarch64-windows*) OS="windows"; ARCH="arm64" ;;
|
||||
*) echo "Unsupported Zig target: $TARGET"; exit 1 ;;
|
||||
esac
|
||||
|
||||
SQLITE_VERSION="3480000"
|
||||
SQLITE_YEAR="2025"
|
||||
SQLITE_URL="https://www.sqlite.org/${SQLITE_YEAR}/sqlite-amalgamation-${SQLITE_VERSION}.zip"
|
||||
SQLITE_DIR="cli/src/assets/sqlite_${OS}_${ARCH}"
|
||||
|
||||
mkdir -p "$SQLITE_DIR"
|
||||
|
||||
echo "Fetching SQLite ${SQLITE_VERSION}..."
|
||||
wget -O /tmp/sqlite.zip "$SQLITE_URL" || \
|
||||
curl -L -o /tmp/sqlite.zip "$SQLITE_URL"
|
||||
|
||||
unzip -q /tmp/sqlite.zip -d /tmp/
|
||||
mv /tmp/sqlite-amalgamation-${SQLITE_VERSION}/* "$SQLITE_DIR/"
|
||||
|
||||
ls -lh "$SQLITE_DIR"/sqlite3.c "$SQLITE_DIR"/sqlite3.h
|
||||
|
||||
- name: Build CLI
|
||||
working-directory: cli
|
||||
run: |
|
||||
|
|
|
|||
90
.forgejo/workflows/security-scan.yml
Normal file
90
.forgejo/workflows/security-scan.yml
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
name: Security Scan
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
|
||||
jobs:
|
||||
security:
|
||||
name: Security Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.25'
|
||||
|
||||
- name: Run govulncheck
|
||||
uses: golang/govulncheck-action@v1
|
||||
with:
|
||||
go-version-input: '1.25'
|
||||
go-package: ./...
|
||||
|
||||
- name: Run gosec
|
||||
uses: securego/gosec@master
|
||||
with:
|
||||
args: '-fmt sarif -out gosec-results.sarif ./...'
|
||||
|
||||
- name: Upload gosec results
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: gosec-results
|
||||
path: gosec-results.sarif
|
||||
|
||||
- name: Check for unsafe package usage
|
||||
run: |
|
||||
if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then
|
||||
echo "ERROR: unsafe package usage detected"
|
||||
exit 1
|
||||
fi
|
||||
echo "✓ No unsafe package usage found"
|
||||
|
||||
- name: Verify dependencies
|
||||
run: |
|
||||
go mod verify
|
||||
echo "✓ Go modules verified"
|
||||
|
||||
native-security:
|
||||
name: Native Library Security
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y cmake build-essential
|
||||
|
||||
- name: Build with AddressSanitizer
|
||||
run: |
|
||||
cd native
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=Debug -DENABLE_ASAN=ON
|
||||
make -j$(nproc)
|
||||
|
||||
- name: Run tests with ASan
|
||||
run: |
|
||||
cd native/build
|
||||
ASAN_OPTIONS=detect_leaks=1 ctest --output-on-failure
|
||||
|
||||
- name: Build with UndefinedBehaviorSanitizer
|
||||
run: |
|
||||
cd native
|
||||
rm -rf build
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_C_FLAGS="-fsanitize=undefined" -DCMAKE_CXX_FLAGS="-fsanitize=undefined"
|
||||
make -j$(nproc)
|
||||
|
||||
- name: Run tests with UBSan
|
||||
run: |
|
||||
cd native/build
|
||||
ctest --output-on-failure
|
||||
17
.gitignore
vendored
17
.gitignore
vendored
|
|
@ -1,3 +1,11 @@
|
|||
# Root directory protection - binaries must be in bin/
|
||||
/api-server
|
||||
/worker
|
||||
/tui
|
||||
/data_manager
|
||||
/coverage.out
|
||||
.DS_Store
|
||||
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
|
|
@ -11,8 +19,8 @@
|
|||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
# Generated SSH test keys for TUI testing
|
||||
deployments/test_keys/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
|
@ -237,8 +245,9 @@ db/*.db
|
|||
*.key
|
||||
*.pem
|
||||
secrets/
|
||||
cli/src/assets/rsync_release.bin
|
||||
cli/src/assets/rsync_release_*.bin
|
||||
# Downloaded assets (platform-specific)
|
||||
cli/src/assets/rsync/rsync_release_*.bin
|
||||
cli/src/assets/sqlite_*/
|
||||
|
||||
# Local artifacts (e.g. test run outputs)
|
||||
.local-artifacts/
|
||||
|
|
|
|||
31
CHANGELOG.md
31
CHANGELOG.md
|
|
@ -1,5 +1,36 @@
|
|||
## [Unreleased]
|
||||
|
||||
### Added - CSV Export Features (2026-02-18)
|
||||
- CLI: `ml compare --csv` - Export run comparisons as CSV with actual run IDs as column headers
|
||||
- CLI: `ml find --csv` - Export search results as CSV for spreadsheet analysis
|
||||
- CLI: `ml dataset verify --csv` - Export dataset verification metrics as CSV
|
||||
- Shell: Updated bash/zsh completions with --csv flags for compare, find commands
|
||||
|
||||
### Added - Phase 3 Features (2026-02-18)
|
||||
- CLI: `ml requeue --with-changes` - Iterative experimentation with config overrides (--lr=0.002, etc.)
|
||||
- CLI: `ml requeue --inherit-narrative` - Copy hypothesis/context from parent run
|
||||
- CLI: `ml requeue --inherit-config` - Copy metadata from parent run
|
||||
- CLI: `ml requeue --parent` - Link as child run for provenance tracking
|
||||
- CLI: `ml dataset verify` - Fast dataset checksum validation
|
||||
- CLI: `ml logs --follow` - Real-time log streaming via WebSocket
|
||||
- API/WebSocket: Add opcodes for compare (0x30), find (0x31), export (0x32), set outcome (0x33)
|
||||
|
||||
### Added - Phase 2 Features (2026-02-18)
|
||||
- CLI: `ml compare` - Diff two runs showing narrative/metadata/metrics differences
|
||||
- CLI: `ml find` - Search experiments by tags, outcome, dataset, experiment-group, author
|
||||
- CLI: `ml export --anonymize` - Export bundles with path/IP/username redaction
|
||||
- CLI: `ml export --anonymize-level` - 'metadata-only' or 'full' anonymization
|
||||
- CLI: `ml outcome set` - Post-run outcome tracking (validates/refutes/inconclusive/partial)
|
||||
- CLI: Error suggestions with Levenshtein distance for typos
|
||||
- Shell: Updated bash/zsh completions for all new commands
|
||||
- Tests: E2E tests for compare, find, export, requeue changes
|
||||
|
||||
### Added - Phase 0 Features (2026-02-18)
|
||||
- CLI: Queue-time narrative flags (--hypothesis, --context, --intent, --expected-outcome, --experiment-group, --tags)
|
||||
- CLI: Enhanced `ml status` output with queue position [pos N] and priority (P:N)
|
||||
- CLI: `ml narrative set` command for setting run narrative fields
|
||||
- Shell: Updated completions with new commands and flags
|
||||
|
||||
### Security
|
||||
- Native: fix buffer overflow vulnerabilities in `dataset_hash` (replaced `strcpy` with `strncpy` + null termination)
|
||||
- Native: fix unsafe `memcpy` in `queue_index` priority queue (added explicit null terminators for string fields)
|
||||
|
|
|
|||
215
Makefile
215
Makefile
|
|
@ -1,4 +1,4 @@
|
|||
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs
|
||||
.PHONY: all build prod prod-with-native native-release native-build native-debug native-test native-smoke native-clean dev clean clean-docs test test-unit test-integration test-e2e test-coverage lint install configlint worker-configlint ci-local docs docs-setup docs-check-port docs-stop docs-build docs-build-prod benchmark benchmark-local artifacts clean-benchmarks clean-all clean-aggressive status size load-test chaos-test profile-load profile-load-norate profile-ws-queue profile-tools detect-regressions tech-excellence docker-build dev-smoke prod-smoke native-smoke self-cleanup test-full test-auth deploy-up deploy-down deploy-status deploy-clean dev-up dev-down dev-status dev-logs prod-up prod-down prod-status prod-logs security-scan gosec govulncheck check-unsafe security-audit test-security check-sqlbuild
|
||||
OK = ✓
|
||||
DOCS_PORT ?= 1313
|
||||
DOCS_BIND ?= 127.0.0.1
|
||||
|
|
@ -8,7 +8,7 @@ DOCS_PROD_BASEURL ?= $(DOCS_BASEURL)
|
|||
all: build
|
||||
|
||||
# Build all components (Go binaries + optimized CLI)
|
||||
build:
|
||||
build: native-build
|
||||
go build -ldflags="-X main.BuildHash=$(shell git rev-parse --short HEAD) -X main.BuildTime=$(shell date -u +%Y%m%d.%H%M%S)" -o bin/api-server ./cmd/api-server/main.go
|
||||
go build -ldflags="-X main.BuildHash=$(shell git rev-parse --short HEAD) -X main.BuildTime=$(shell date -u +%Y%m%d.%H%M%S)" -o bin/worker ./cmd/worker/worker_server.go
|
||||
go build -ldflags="-X main.BuildHash=$(shell git rev-parse --short HEAD) -X main.BuildTime=$(shell date -u +%Y%m%d.%H%M%S)" -o bin/data_manager ./cmd/data_manager
|
||||
|
|
@ -75,7 +75,7 @@ native-test: native-build
|
|||
|
||||
# Run native libraries smoke test (builds + C++ tests + Go integration)
|
||||
native-smoke:
|
||||
@bash ./scripts/smoke-test-native.sh
|
||||
@bash ./scripts/dev/smoke-test.sh --native
|
||||
@echo "${OK} Native smoke test passed"
|
||||
|
||||
# Build production-optimized binaries
|
||||
|
|
@ -125,6 +125,27 @@ clean:
|
|||
go clean
|
||||
@echo "${OK} Cleaned"
|
||||
|
||||
# Thorough cleanup for release
|
||||
clean-release: clean
|
||||
@echo "Release cleanup..."
|
||||
rm -rf bin/ cli/zig-out/ cli/.zig-cache/
|
||||
rm -rf dist/ coverage/ tests/bin/ native/build/
|
||||
go clean -cache -testcache
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -name "*.pyc" -delete 2>/dev/null || true
|
||||
@./scripts/release/cleanup.sh testdata 2>/dev/null || true
|
||||
@./scripts/release/cleanup.sh logs 2>/dev/null || true
|
||||
@./scripts/release/cleanup.sh state 2>/dev/null || true
|
||||
@echo "${OK} Release cleanup complete"
|
||||
|
||||
# Run release verification checks
|
||||
release-check:
|
||||
@./scripts/release/verify.sh
|
||||
|
||||
# Full release preparation
|
||||
prepare-release: clean-release release-check
|
||||
@echo "${OK} Release preparation complete"
|
||||
|
||||
clean-docs:
|
||||
rm -rf docs/_site/
|
||||
@echo "${OK} Cleaned docs"
|
||||
|
|
@ -171,17 +192,22 @@ worker-configlint:
|
|||
configs/workers/docker-prod.yaml \
|
||||
configs/workers/homelab-secure.yaml
|
||||
|
||||
# Check CLI builds correctly (SQLite handled automatically by build.zig)
|
||||
build-cli:
|
||||
@$(MAKE) -C cli all
|
||||
@echo "${OK} CLI built successfully"
|
||||
|
||||
dev-smoke:
|
||||
bash ./scripts/smoke-test.sh dev
|
||||
bash ./scripts/dev/smoke-test.sh dev
|
||||
@echo "dev smoke: OK"
|
||||
|
||||
prod-smoke:
|
||||
bash ./scripts/smoke-test.sh prod
|
||||
bash ./scripts/dev/smoke-test.sh prod
|
||||
@echo "prod smoke: OK"
|
||||
|
||||
# Run maintainability CI checks
|
||||
ci-checks:
|
||||
@bash ./scripts/ci-checks.sh
|
||||
@bash ./scripts/ci/checks.sh
|
||||
|
||||
# Run a local approximation of the CI pipeline
|
||||
ci-local: ci-checks test lint configlint worker-configlint
|
||||
|
|
@ -265,7 +291,7 @@ benchmark-compare:
|
|||
# Manage benchmark artifacts
|
||||
artifacts:
|
||||
@echo "Managing benchmark artifacts..."
|
||||
./scripts/manage-artifacts.sh help
|
||||
./scripts/dev/manage-artifacts.sh help
|
||||
|
||||
# Clean benchmark artifacts (keep last 10)
|
||||
clean-benchmarks:
|
||||
|
|
@ -360,6 +386,8 @@ help:
|
|||
@echo " make prod-with-native - Build production binaries with native C++ libs"
|
||||
@echo " make dev - Build development binaries (faster)"
|
||||
@echo " make clean - Remove build artifacts"
|
||||
@echo " make clean-release - Thorough cleanup for release"
|
||||
@echo " make prepare-release - Full release preparation (cleanup + verify)"
|
||||
@echo ""
|
||||
@echo "Native Library Targets:"
|
||||
@echo " make native-build - Build native C++ libraries"
|
||||
|
|
@ -475,3 +503,176 @@ prod-status:
|
|||
|
||||
prod-logs:
|
||||
@./deployments/deploy.sh prod logs
|
||||
|
||||
# =============================================================================
|
||||
# SECURITY TARGETS
|
||||
# =============================================================================
|
||||
|
||||
.PHONY: security-scan gosec govulncheck check-unsafe security-audit
|
||||
|
||||
# Run all security scans
|
||||
security-scan: gosec govulncheck check-unsafe
|
||||
@echo "${OK} Security scan complete"
|
||||
|
||||
# Run gosec security linter
|
||||
gosec:
|
||||
@mkdir -p reports
|
||||
@echo "Running gosec security scan..."
|
||||
@if command -v gosec >/dev/null 2>&1; then \
|
||||
gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \
|
||||
gosec -fmt=sarif -out=reports/gosec-results.sarif ./... 2>/dev/null || true; \
|
||||
gosec ./... 2>/dev/null || echo "Note: gosec found issues (see reports/gosec-results.json)"; \
|
||||
else \
|
||||
echo "Installing gosec..."; \
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest; \
|
||||
gosec -fmt=json -out=reports/gosec-results.json ./... 2>/dev/null || true; \
|
||||
fi
|
||||
@echo "${OK} gosec scan complete (see reports/gosec-results.*)"
|
||||
|
||||
# Run govulncheck for known vulnerabilities
|
||||
govulncheck:
|
||||
@echo "Running govulncheck for known vulnerabilities..."
|
||||
@if command -v govulncheck >/dev/null 2>&1; then \
|
||||
govulncheck ./...; \
|
||||
else \
|
||||
echo "Installing govulncheck..."; \
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest; \
|
||||
govulncheck ./...; \
|
||||
fi
|
||||
@echo "${OK} govulncheck complete"
|
||||
|
||||
# Check for unsafe package usage
|
||||
check-unsafe:
|
||||
@echo "Checking for unsafe package usage..."
|
||||
@if grep -r "unsafe\." --include="*.go" ./internal ./cmd ./pkg 2>/dev/null; then \
|
||||
echo "WARNING: Found unsafe package usage (review required)"; \
|
||||
exit 1; \
|
||||
else \
|
||||
echo "${OK} No unsafe package usage found"; \
|
||||
fi
|
||||
|
||||
# Full security audit (tests + scans)
|
||||
security-audit: security-scan test-security
|
||||
@echo "${OK} Full security audit complete"
|
||||
|
||||
# Run security-specific tests
|
||||
test-security:
|
||||
@echo "Running security tests..."
|
||||
@go test -v ./tests/security/... 2>/dev/null || echo "Note: No security tests yet (will be added in Phase 5)"
|
||||
@echo "${OK} Security tests complete"
|
||||
|
||||
# =============================================================================
|
||||
# API / OPENAPI TARGETS
|
||||
# =============================================================================
|
||||
|
||||
.PHONY: openapi-generate openapi-validate
|
||||
|
||||
# Generate Go types from OpenAPI spec (default, safe)
|
||||
openapi-generate:
|
||||
@echo "Generating Go types from OpenAPI spec..."
|
||||
@if command -v oapi-codegen >/dev/null 2>&1; then \
|
||||
oapi-codegen -package api -generate types api/openapi.yaml > internal/api/types.go; \
|
||||
echo "${OK} Generated internal/api/types.go"; \
|
||||
else \
|
||||
echo "Installing oapi-codegen v2..."; \
|
||||
go install github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen@latest; \
|
||||
oapi-codegen -package api -generate types api/openapi.yaml > internal/api/types.go; \
|
||||
echo "${OK} Generated internal/api/types.go"; \
|
||||
fi
|
||||
|
||||
# Generate full server interfaces (for Phase 5 migration)
|
||||
openapi-generate-server:
|
||||
@echo "Generating Go server code from OpenAPI spec..."
|
||||
@if command -v oapi-codegen >/dev/null 2>&1; then \
|
||||
oapi-codegen -package api -generate types,server,spec api/openapi.yaml > internal/api/server_gen.go; \
|
||||
echo "${OK} Generated internal/api/server_gen.go"; \
|
||||
else \
|
||||
echo "Installing oapi-codegen v2..."; \
|
||||
go install github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen@latest; \
|
||||
oapi-codegen -package api -generate types,server,spec api/openapi.yaml > internal/api/server_gen.go; \
|
||||
echo "${OK} Generated internal/api/server_gen.go"; \
|
||||
fi
|
||||
|
||||
# Validate OpenAPI spec against schema
|
||||
openapi-validate:
|
||||
@echo "Validating OpenAPI spec..."
|
||||
@if command -v openapi-generator >/dev/null 2>&1; then \
|
||||
openapi-generator validate -i api/openapi.yaml 2>/dev/null || echo "Note: Validation warnings (non-fatal)"; \
|
||||
else \
|
||||
echo "Note: Install openapi-generator for validation (https://openapi-generator.tech)"; \
|
||||
fi
|
||||
@echo "${OK} OpenAPI validation complete"
|
||||
|
||||
# CI check: fail if openapi.yaml changed but types.go didn't
|
||||
openapi-check-ci: openapi-generate
|
||||
@git diff --exit-code internal/api/types.go 2>/dev/null || \
|
||||
(echo "ERROR: OpenAPI spec changed but generated types not updated. Run 'make openapi-generate'" && exit 1)
|
||||
@echo "${OK} OpenAPI types are up to date"
|
||||
|
||||
# CI check: verify spec matches implementation (generated code is in sync)
|
||||
openapi-check-implementation:
|
||||
@echo "Verifying spec matches implementation..."
|
||||
@# Generate fresh types
|
||||
$(MAKE) openapi-generate-server
|
||||
@# Check for drift
|
||||
@git diff --exit-code internal/api/server_gen.go 2>/dev/null || \
|
||||
(echo "ERROR: Implementation drift detected. Regenerate with 'make openapi-generate-server'" && exit 1)
|
||||
@echo "${OK} Spec and implementation in sync"
|
||||
|
||||
# Generate static HTML API documentation
|
||||
docs-generate:
|
||||
@echo "Generating API documentation..."
|
||||
@mkdir -p docs/api
|
||||
@if command -v npx >/dev/null 2>&1; then \
|
||||
npx @redocly/cli build-docs api/openapi.yaml \
|
||||
--output docs/api/index.html \
|
||||
--title "FetchML API" 2>/dev/null || \
|
||||
echo "Note: Redocly CLI not available. Install with: npm install -g @redocly/cli"; \
|
||||
else \
|
||||
echo "Note: npx not available. Install Node.js to generate docs"; \
|
||||
fi
|
||||
@echo "${OK} Docs generation complete"
|
||||
|
||||
# Serve API documentation locally
|
||||
docs-serve:
|
||||
@if command -v npx >/dev/null 2>&1; then \
|
||||
npx @redocly/cli preview-docs api/openapi.yaml; \
|
||||
else \
|
||||
echo "Note: npx not available. Install Node.js to serve docs"; \
|
||||
fi
|
||||
|
||||
# Generate TypeScript client SDK
|
||||
openapi-generate-ts:
|
||||
@echo "Generating TypeScript client..."
|
||||
@mkdir -p sdk/typescript
|
||||
@if command -v npx >/dev/null 2>&1; then \
|
||||
npx @openapitools/openapi-generator-cli generate \
|
||||
-i api/openapi.yaml \
|
||||
-g typescript-fetch \
|
||||
-o sdk/typescript \
|
||||
--additional-properties=supportsES6=true,npmName=fetchml-client 2>/dev/null || \
|
||||
echo "Note: openapi-generator-cli not available. Install with: npm install -g @openapitools/openapi-generator-cli"; \
|
||||
else \
|
||||
echo "Note: npx not available. Install Node.js to generate TypeScript client"; \
|
||||
fi
|
||||
@echo "${OK} TypeScript client generation complete"
|
||||
|
||||
# Generate Python client SDK
|
||||
openapi-generate-python:
|
||||
@echo "Generating Python client..."
|
||||
@mkdir -p sdk/python
|
||||
@if command -v npx >/dev/null 2>&1; then \
|
||||
npx @openapitools/openapi-generator-cli generate \
|
||||
-i api/openapi.yaml \
|
||||
-g python \
|
||||
-o sdk/python \
|
||||
--additional-properties=packageName=fetchml_client 2>/dev/null || \
|
||||
echo "Note: openapi-generator-cli not available. Install with: npm install -g @openapitools/openapi-generator-cli"; \
|
||||
else \
|
||||
echo "Note: npx not available. Install Node.js to generate Python client"; \
|
||||
fi
|
||||
@echo "${OK} Python client generation complete"
|
||||
|
||||
# Generate all client SDKs
|
||||
openapi-generate-clients: openapi-generate-ts openapi-generate-python
|
||||
@echo "${OK} All client SDKs generated"
|
||||
|
|
|
|||
28
README.md
28
README.md
|
|
@ -100,6 +100,15 @@ ml queue my-job
|
|||
ml cancel my-job
|
||||
ml dataset list
|
||||
ml monitor # SSH to run TUI remotely
|
||||
|
||||
# Research features (see docs/src/research-features.md)
|
||||
ml queue train.py --hypothesis "LR scaling..." --tags ablation
|
||||
ml outcome set run_abc --outcome validates --summary "Accuracy +2%"
|
||||
ml find --outcome validates --tag lr-test
|
||||
ml compare run_abc run_def
|
||||
ml privacy set run_abc --level team
|
||||
ml export run_abc --anonymize
|
||||
ml dataset verify /path/to/data
|
||||
```
|
||||
|
||||
## Phase 1 (V1) notes
|
||||
|
|
@ -121,9 +130,9 @@ FetchML includes optional C++ native libraries for performance. See `docs/src/na
|
|||
|
||||
Quick start:
|
||||
```bash
|
||||
make native-build # Build native libs
|
||||
make native-smoke # Run smoke test
|
||||
export FETCHML_NATIVE_LIBS=1 # Enable at runtime
|
||||
make native-build # Build native libs
|
||||
make native-smoke # Run smoke test
|
||||
go build -tags native_libs # Enable native libraries
|
||||
```
|
||||
|
||||
### Standard Build
|
||||
|
|
@ -150,6 +159,19 @@ See `docs/` for detailed guides:
|
|||
- `docs/src/zig-cli.md` – CLI reference
|
||||
- `docs/src/quick-start.md` – Full setup guide
|
||||
- `docs/src/deployment.md` – Production deployment
|
||||
- `docs/src/research-features.md` – Research workflow features (narrative capture, outcomes, search)
|
||||
- `docs/src/privacy-security.md` – Privacy levels, PII detection, anonymized export
|
||||
|
||||
## CLI Architecture (2026-02)
|
||||
|
||||
The Zig CLI has been refactored for improved maintainability:
|
||||
|
||||
- **Modular 3-layer architecture**: `core/` (foundation), `local/`/`server/` (mode-specific), `commands/` (routers)
|
||||
- **Unified context**: `core.context.Context` handles mode detection, output formatting, and dispatch
|
||||
- **Code reduction**: `experiment.zig` reduced from 836 to 348 lines (58% reduction)
|
||||
- **Bug fixes**: Resolved 15+ compilation errors across multiple commands
|
||||
|
||||
See `cli/README.md` for detailed architecture documentation.
|
||||
|
||||
## Source code
|
||||
|
||||
|
|
|
|||
149
SECURITY.md
149
SECURITY.md
|
|
@ -1,6 +1,153 @@
|
|||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Please report security vulnerabilities to security@fetchml.io.
|
||||
Do NOT open public issues for security bugs.
|
||||
|
||||
Response timeline:
|
||||
- Acknowledgment: within 48 hours
|
||||
- Initial assessment: within 5 days
|
||||
- Fix released: within 30 days (critical), 90 days (high)
|
||||
|
||||
## Security Features
|
||||
|
||||
FetchML implements defense-in-depth security for ML research systems:
|
||||
|
||||
### Authentication & Authorization
|
||||
- **Argon2id API Key Hashing**: Memory-hard hashing resists GPU cracking
|
||||
- **RBAC with Role Inheritance**: Granular permissions (admin, data_scientist, data_engineer, viewer, operator)
|
||||
- **Constant-time Comparison**: Prevents timing attacks on key validation
|
||||
|
||||
### Cryptographic Practices
|
||||
- **Ed25519 Manifest Signing**: Tamper detection for run manifests
|
||||
- **SHA-256 with Salt**: Legacy key support with migration path
|
||||
- **Secure Key Generation**: 256-bit entropy for all API keys
|
||||
|
||||
### Container Security
|
||||
- **Rootless Podman**: No privileged containers
|
||||
- **Capability Dropping**: `--cap-drop ALL` by default
|
||||
- **No New Privileges**: `no-new-privileges` security opt
|
||||
- **Read-only Root Filesystem**: Immutable base image
|
||||
|
||||
### Input Validation
|
||||
- **Path Traversal Prevention**: Canonical path validation
|
||||
- **Command Injection Protection**: Shell metacharacter filtering
|
||||
- **Length Limits**: Prevents DoS via oversized inputs
|
||||
|
||||
### Audit & Monitoring
|
||||
- **Structured Audit Logging**: JSON-formatted security events
|
||||
- **Hash-chained Logs**: Tamper-evident audit trail
|
||||
- **Anomaly Detection**: Brute force, privilege escalation alerts
|
||||
- **Security Metrics**: Prometheus integration
|
||||
|
||||
### Supply Chain
|
||||
- **Dependency Scanning**: gosec + govulncheck in CI
|
||||
- **No unsafe Package**: Prohibited in production code
|
||||
- **Manifest Signing**: Ed25519 signatures for integrity
|
||||
|
||||
## Supported Versions
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| 0.2.x | :white_check_mark: |
|
||||
| 0.1.x | :x: |
|
||||
|
||||
## Security Checklist (Pre-Release)
|
||||
|
||||
### Code Review
|
||||
- [ ] No hardcoded secrets
|
||||
- [ ] No `unsafe` usage without justification
|
||||
- [ ] All user inputs validated
|
||||
- [ ] All file paths canonicalized
|
||||
- [ ] No secrets in error messages
|
||||
|
||||
### Dependency Audit
|
||||
- [ ] `go mod verify` passes
|
||||
- [ ] `govulncheck` shows no vulnerabilities
|
||||
- [ ] All dependencies pinned
|
||||
- [ ] No unmaintained dependencies
|
||||
|
||||
### Container Security
|
||||
- [ ] No privileged containers
|
||||
- [ ] Rootless execution
|
||||
- [ ] Seccomp/AppArmor applied
|
||||
- [ ] Network isolation
|
||||
|
||||
### Cryptography
|
||||
- [ ] Argon2id for key hashing
|
||||
- [ ] Ed25519 for signing
|
||||
- [ ] TLS 1.3 only
|
||||
- [ ] No weak ciphers
|
||||
|
||||
### Testing
|
||||
- [ ] Security tests pass
|
||||
- [ ] Fuzz tests for parsers
|
||||
- [ ] Authentication bypass tested
|
||||
- [ ] Container escape tested
|
||||
|
||||
## Security Commands
|
||||
|
||||
```bash
|
||||
# Run security scan
|
||||
make security-scan
|
||||
|
||||
# Check for vulnerabilities
|
||||
govulncheck ./...
|
||||
|
||||
# Static analysis
|
||||
gosec ./...
|
||||
|
||||
# Check for unsafe usage
|
||||
grep -r "unsafe\." --include="*.go" ./internal ./cmd
|
||||
|
||||
# Build with sanitizers
|
||||
cd native && cmake -DENABLE_ASAN=ON .. && make
|
||||
```
|
||||
|
||||
## Threat Model
|
||||
|
||||
### Attack Surfaces
|
||||
1. **External API**: Researchers submitting malicious jobs
|
||||
2. **Container Runtime**: Escape to host system
|
||||
3. **Data Exfiltration**: Stealing datasets/models
|
||||
4. **Privilege Escalation**: Researcher → admin
|
||||
5. **Supply Chain**: Compromised dependencies
|
||||
6. **Secrets Leakage**: API keys in logs/errors
|
||||
|
||||
### Mitigations
|
||||
| Threat | Mitigation |
|
||||
|--------|------------|
|
||||
| Malicious Jobs | Input validation, container sandboxing, resource limits |
|
||||
| Container Escape | Rootless, no-new-privileges, seccomp, read-only root |
|
||||
| Data Exfiltration | Network policies, audit logging, rate limiting |
|
||||
| Privilege Escalation | RBAC, least privilege, anomaly detection |
|
||||
| Supply Chain | Dependency scanning, manifest signing, pinned versions |
|
||||
| Secrets Leakage | Log sanitization, secrets manager, memory clearing |
|
||||
|
||||
## Responsible Disclosure
|
||||
|
||||
We follow responsible disclosure practices:
|
||||
|
||||
1. **Report privately**: Email security@fetchml.io with details
|
||||
2. **Provide details**: Steps to reproduce, impact assessment
|
||||
3. **Allow time**: We need 30-90 days to fix before public disclosure
|
||||
4. **Acknowledgment**: We credit researchers who report valid issues
|
||||
|
||||
## Security Team
|
||||
|
||||
- security@fetchml.io - Security issues and questions
|
||||
- security-response@fetchml.io - Active incident response
|
||||
|
||||
---
|
||||
|
||||
*Last updated: 2026-02-19*
|
||||
|
||||
---
|
||||
|
||||
# Security Guide for Fetch ML Homelab
|
||||
|
||||
This guide covers security best practices for deploying Fetch ML in a homelab environment.
|
||||
*The following section covers security best practices for deploying Fetch ML in a homelab environment.*
|
||||
|
||||
## Quick Setup
|
||||
|
||||
|
|
|
|||
559
api/openapi.yaml
Normal file
559
api/openapi.yaml
Normal file
|
|
@ -0,0 +1,559 @@
|
|||
openapi: 3.0.3
|
||||
info:
|
||||
title: ML Worker API
|
||||
description: |
|
||||
API for managing ML experiment tasks and Jupyter services.
|
||||
|
||||
## Security
|
||||
All endpoints (except health checks) require API key authentication via the
|
||||
`X-API-Key` header. Rate limiting is enforced per API key.
|
||||
|
||||
## Error Handling
|
||||
Errors follow a consistent format with machine-readable codes and trace IDs:
|
||||
```json
|
||||
{
|
||||
"error": "Sanitized error message",
|
||||
"code": "ERROR_CODE",
|
||||
"trace_id": "uuid-for-support"
|
||||
}
|
||||
```
|
||||
version: 1.0.0
|
||||
contact:
|
||||
name: FetchML Support
|
||||
|
||||
servers:
|
||||
- url: http://localhost:9101
|
||||
description: Local development server
|
||||
- url: https://api.fetchml.example.com
|
||||
description: Production server
|
||||
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
|
||||
paths:
|
||||
/health:
|
||||
get:
|
||||
summary: Health check
|
||||
description: Returns server health status. No authentication required.
|
||||
security: []
|
||||
responses:
|
||||
'200':
|
||||
description: Server is healthy
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HealthResponse'
|
||||
|
||||
/v1/tasks:
|
||||
get:
|
||||
summary: List tasks
|
||||
description: List all tasks with optional filtering
|
||||
parameters:
|
||||
- name: status
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
enum: [queued, running, completed, failed]
|
||||
- name: limit
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 50
|
||||
maximum: 1000
|
||||
- name: offset
|
||||
in: query
|
||||
schema:
|
||||
type: integer
|
||||
default: 0
|
||||
responses:
|
||||
'200':
|
||||
description: List of tasks
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/TaskList'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
'429':
|
||||
$ref: '#/components/responses/RateLimited'
|
||||
|
||||
post:
|
||||
summary: Create task
|
||||
description: Submit a new ML experiment task
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateTaskRequest'
|
||||
responses:
|
||||
'201':
|
||||
description: Task created successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Task'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
'422':
|
||||
$ref: '#/components/responses/ValidationError'
|
||||
'429':
|
||||
$ref: '#/components/responses/RateLimited'
|
||||
|
||||
/v1/tasks/{taskId}:
|
||||
get:
|
||||
summary: Get task details
|
||||
parameters:
|
||||
- name: taskId
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: Task details
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Task'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
|
||||
delete:
|
||||
summary: Cancel/delete task
|
||||
parameters:
|
||||
- name: taskId
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'204':
|
||||
description: Task cancelled
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
|
||||
/v1/queue:
|
||||
get:
|
||||
summary: Queue status
|
||||
description: Get current queue statistics
|
||||
responses:
|
||||
'200':
|
||||
description: Queue statistics
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/QueueStats'
|
||||
|
||||
/v1/experiments:
|
||||
get:
|
||||
summary: List experiments
|
||||
description: List all experiments
|
||||
responses:
|
||||
'200':
|
||||
description: List of experiments
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Experiment'
|
||||
|
||||
post:
|
||||
summary: Create experiment
|
||||
description: Create a new experiment
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateExperimentRequest'
|
||||
responses:
|
||||
'201':
|
||||
description: Experiment created
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Experiment'
|
||||
|
||||
/v1/jupyter/services:
|
||||
get:
|
||||
summary: List Jupyter services
|
||||
responses:
|
||||
'200':
|
||||
description: List of Jupyter services
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/JupyterService'
|
||||
|
||||
post:
|
||||
summary: Start Jupyter service
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/StartJupyterRequest'
|
||||
responses:
|
||||
'201':
|
||||
description: Jupyter service started
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JupyterService'
|
||||
|
||||
/v1/jupyter/services/{serviceId}:
|
||||
delete:
|
||||
summary: Stop Jupyter service
|
||||
parameters:
|
||||
- name: serviceId
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'204':
|
||||
description: Service stopped
|
||||
|
||||
/ws:
|
||||
get:
|
||||
summary: WebSocket connection
|
||||
description: |
|
||||
WebSocket endpoint for real-time task updates.
|
||||
|
||||
## Message Types
|
||||
- `task_update`: Task status changes
|
||||
- `task_complete`: Task finished
|
||||
- `ping`: Keep-alive (respond with `pong`)
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
responses:
|
||||
'101':
|
||||
description: WebSocket connection established
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
ApiKeyAuth:
|
||||
type: apiKey
|
||||
in: header
|
||||
name: X-API-Key
|
||||
description: API key for authentication
|
||||
|
||||
schemas:
|
||||
HealthResponse:
|
||||
type: object
|
||||
properties:
|
||||
status:
|
||||
type: string
|
||||
enum: [healthy, degraded, unhealthy]
|
||||
version:
|
||||
type: string
|
||||
timestamp:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
Task:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique task identifier
|
||||
job_name:
|
||||
type: string
|
||||
pattern: '^[a-zA-Z0-9_-]+$'
|
||||
maxLength: 64
|
||||
status:
|
||||
type: string
|
||||
enum: [queued, preparing, running, collecting, completed, failed]
|
||||
priority:
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: 10
|
||||
default: 5
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
started_at:
|
||||
type: string
|
||||
format: date-time
|
||||
ended_at:
|
||||
type: string
|
||||
format: date-time
|
||||
worker_id:
|
||||
type: string
|
||||
error:
|
||||
type: string
|
||||
output:
|
||||
type: string
|
||||
snapshot_id:
|
||||
type: string
|
||||
datasets:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
cpu:
|
||||
type: integer
|
||||
memory_gb:
|
||||
type: integer
|
||||
gpu:
|
||||
type: integer
|
||||
user_id:
|
||||
type: string
|
||||
retry_count:
|
||||
type: integer
|
||||
max_retries:
|
||||
type: integer
|
||||
|
||||
CreateTaskRequest:
|
||||
type: object
|
||||
required:
|
||||
- job_name
|
||||
properties:
|
||||
job_name:
|
||||
type: string
|
||||
pattern: '^[a-zA-Z0-9_-]+$'
|
||||
maxLength: 64
|
||||
description: Unique identifier for the job
|
||||
priority:
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: 10
|
||||
default: 5
|
||||
args:
|
||||
type: string
|
||||
description: Command-line arguments for the training script
|
||||
snapshot_id:
|
||||
type: string
|
||||
description: Reference to experiment snapshot
|
||||
datasets:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
dataset_specs:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/DatasetSpec'
|
||||
cpu:
|
||||
type: integer
|
||||
description: CPU cores requested
|
||||
memory_gb:
|
||||
type: integer
|
||||
description: Memory (GB) requested
|
||||
gpu:
|
||||
type: integer
|
||||
description: GPUs requested
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
|
||||
DatasetSpec:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
sha256:
|
||||
type: string
|
||||
mount_path:
|
||||
type: string
|
||||
|
||||
TaskList:
|
||||
type: object
|
||||
properties:
|
||||
tasks:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Task'
|
||||
total:
|
||||
type: integer
|
||||
limit:
|
||||
type: integer
|
||||
offset:
|
||||
type: integer
|
||||
|
||||
QueueStats:
|
||||
type: object
|
||||
properties:
|
||||
queued:
|
||||
type: integer
|
||||
description: Tasks waiting to run
|
||||
running:
|
||||
type: integer
|
||||
description: Tasks currently executing
|
||||
completed:
|
||||
type: integer
|
||||
description: Tasks completed today
|
||||
failed:
|
||||
type: integer
|
||||
description: Tasks failed today
|
||||
workers:
|
||||
type: integer
|
||||
description: Active workers
|
||||
|
||||
Experiment:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
commit_id:
|
||||
type: string
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
status:
|
||||
type: string
|
||||
enum: [active, archived, deleted]
|
||||
|
||||
CreateExperimentRequest:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
maxLength: 128
|
||||
description:
|
||||
type: string
|
||||
|
||||
JupyterService:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
enum: [starting, running, stopping, stopped, error]
|
||||
url:
|
||||
type: string
|
||||
format: uri
|
||||
token:
|
||||
type: string
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
StartJupyterRequest:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
workspace:
|
||||
type: string
|
||||
image:
|
||||
type: string
|
||||
default: jupyter/pytorch:latest
|
||||
|
||||
ErrorResponse:
|
||||
type: object
|
||||
required:
|
||||
- error
|
||||
- code
|
||||
- trace_id
|
||||
properties:
|
||||
error:
|
||||
type: string
|
||||
description: Sanitized error message
|
||||
code:
|
||||
type: string
|
||||
enum: [BAD_REQUEST, UNAUTHORIZED, FORBIDDEN, NOT_FOUND, CONFLICT, RATE_LIMITED, INTERNAL_ERROR, SERVICE_UNAVAILABLE, VALIDATION_ERROR]
|
||||
trace_id:
|
||||
type: string
|
||||
description: Support correlation ID
|
||||
|
||||
responses:
|
||||
BadRequest:
|
||||
description: Invalid request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: Invalid request format
|
||||
code: BAD_REQUEST
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
Unauthorized:
|
||||
description: Authentication required
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: Invalid or missing API key
|
||||
code: UNAUTHORIZED
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
Forbidden:
|
||||
description: Insufficient permissions
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: Insufficient permissions
|
||||
code: FORBIDDEN
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: Resource not found
|
||||
code: NOT_FOUND
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
ValidationError:
|
||||
description: Validation failed
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: Validation failed
|
||||
code: VALIDATION_ERROR
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
||||
RateLimited:
|
||||
description: Too many requests
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: Rate limit exceeded
|
||||
code: RATE_LIMITED
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
headers:
|
||||
Retry-After:
|
||||
schema:
|
||||
type: integer
|
||||
description: Seconds until rate limit resets
|
||||
|
||||
InternalError:
|
||||
description: Internal server error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
example:
|
||||
error: An error occurred
|
||||
code: INTERNAL_ERROR
|
||||
trace_id: a1b2c3d4-e5f6-7890-abcd-ef1234567890
|
||||
|
|
@ -22,9 +22,9 @@ RUN rm -rf native/build && cd native && mkdir -p build && cd build && \
|
|||
cmake .. -DCMAKE_BUILD_TYPE=Release && \
|
||||
make -j$(nproc)
|
||||
|
||||
# Build Go binaries with native libs
|
||||
RUN CGO_ENABLED=1 go build -o bin/api-server cmd/api-server/main.go && \
|
||||
CGO_ENABLED=1 go build -o bin/worker ./cmd/worker
|
||||
# Build Go binaries with native libs enabled via build tag
|
||||
RUN CGO_ENABLED=1 go build -tags native_libs -o bin/api-server cmd/api-server/main.go && \
|
||||
CGO_ENABLED=1 go build -tags native_libs -o bin/worker ./cmd/worker
|
||||
|
||||
# Final stage
|
||||
FROM alpine:3.19
|
||||
|
|
|
|||
46
cli/Makefile
46
cli/Makefile
|
|
@ -4,13 +4,17 @@ ZIG ?= zig
|
|||
BUILD_DIR ?= zig-out/bin
|
||||
BINARY := $(BUILD_DIR)/ml
|
||||
|
||||
.PHONY: all prod dev test build-rsync install clean help
|
||||
.PHONY: all prod dev test build-rsync build-sqlite install clean help
|
||||
|
||||
RSYNC_VERSION ?= 3.3.0
|
||||
RSYNC_SRC_BASE ?= https://download.samba.org/pub/rsync/src
|
||||
RSYNC_TARBALL ?= rsync-$(RSYNC_VERSION).tar.gz
|
||||
RSYNC_TARBALL_SHA256 ?=
|
||||
|
||||
SQLITE_VERSION ?= 3480000
|
||||
SQLITE_YEAR ?= 2025
|
||||
SQLITE_SRC_BASE ?= https://www.sqlite.org/2025
|
||||
|
||||
all: $(BINARY)
|
||||
|
||||
$(BUILD_DIR):
|
||||
|
|
@ -19,12 +23,33 @@ $(BUILD_DIR):
|
|||
$(BINARY): | $(BUILD_DIR)
|
||||
$(ZIG) build --release=small
|
||||
|
||||
prod: src/main.zig | $(BUILD_DIR)
|
||||
UNAME_S := $(shell uname -s)
|
||||
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
PROD_DEPS := build-rsync build-sqlite
|
||||
else
|
||||
PROD_DEPS := build-sqlite
|
||||
endif
|
||||
|
||||
# Production build: optimized for speed with ReleaseFast + LTO
|
||||
prod: $(PROD_DEPS) | $(BUILD_DIR)
|
||||
$(ZIG) build --release=fast
|
||||
|
||||
# Tiny build: smallest binary with ReleaseSmall
|
||||
# Note: Requires SQLite amalgamation
|
||||
.PHONY: tiny
|
||||
tiny: $(PROD_DEPS) | $(BUILD_DIR)
|
||||
$(ZIG) build --release=small
|
||||
|
||||
# Development build: fast compilation + optimizations
|
||||
dev: src/main.zig | $(BUILD_DIR)
|
||||
$(ZIG) build --release=fast
|
||||
|
||||
# Debug build: fastest compilation, no optimizations
|
||||
.PHONY: debug
|
||||
debug: src/main.zig | $(BUILD_DIR)
|
||||
$(ZIG) build -Doptimize=Debug
|
||||
|
||||
test:
|
||||
$(ZIG) build test
|
||||
|
||||
|
|
@ -35,6 +60,12 @@ build-rsync:
|
|||
RSYNC_TARBALL_SHA256="$(RSYNC_TARBALL_SHA256)" \
|
||||
bash "$(CURDIR)/scripts/build_rsync.sh"
|
||||
|
||||
build-sqlite:
|
||||
@SQLITE_VERSION="${SQLITE_VERSION:-3480000}" \
|
||||
SQLITE_YEAR="${SQLITE_YEAR:-2025}" \
|
||||
SQLITE_SRC_BASE="$(SQLITE_SRC_BASE)" \
|
||||
bash "$(CURDIR)/scripts/build_sqlite.sh"
|
||||
|
||||
install: $(BINARY)
|
||||
install -d $(DESTDIR)/usr/local/bin
|
||||
install -m 0755 $(BINARY) $(DESTDIR)/usr/local/bin/ml
|
||||
|
|
@ -44,10 +75,13 @@ clean:
|
|||
|
||||
help:
|
||||
@echo "Targets:"
|
||||
@echo " all - build release-small binary (default)"
|
||||
@echo " prod - build production binary with ReleaseSmall"
|
||||
@echo " dev - build development binary with ReleaseFast"
|
||||
@echo " prod - build production binary with ReleaseFast + LTO (best performance)"
|
||||
@echo " tiny - build minimal binary with ReleaseSmall (smallest size)"
|
||||
@echo " dev - build development binary with ReleaseFast (quick builds)"
|
||||
@echo " debug - build debug binary with no optimizations (fastest compile)"
|
||||
@echo " all - build release-small binary (legacy, same as 'tiny')"
|
||||
@echo " test - run Zig unit tests"
|
||||
@echo " build-rsync - build pinned rsync from official source into src/assets (RSYNC_VERSION=... override)"
|
||||
@echo " build-rsync - build pinned rsync from official source into src/assets"
|
||||
@echo " build-sqlite - fetch SQLite amalgamation into src/assets"
|
||||
@echo " install - copy binary into /usr/local/bin"
|
||||
@echo " clean - remove build artifacts"
|
||||
219
cli/README.md
219
cli/README.md
|
|
@ -1,6 +1,43 @@
|
|||
# ML CLI
|
||||
|
||||
Fast CLI tool for managing ML experiments.
|
||||
Fast CLI tool for managing ML experiments. Supports both **local mode** (SQLite) and **server mode** (WebSocket).
|
||||
|
||||
## Architecture
|
||||
|
||||
The CLI follows a modular 3-layer architecture for maintainability:
|
||||
|
||||
```
|
||||
src/
|
||||
├── core/ # Shared foundation
|
||||
│ ├── context.zig # Execution context (allocator, config, mode dispatch)
|
||||
│ ├── output.zig # Unified JSON/text output helpers
|
||||
│ └── flags.zig # Common flag parsing
|
||||
├── local/ # Local mode operations (SQLite)
|
||||
│ └── experiment_ops.zig # Experiment CRUD for local DB
|
||||
├── server/ # Server mode operations (WebSocket)
|
||||
│ └── experiment_api.zig # Experiment API for remote server
|
||||
├── commands/ # Thin command routers
|
||||
│ ├── experiment.zig # ~100 lines (was 887)
|
||||
│ ├── queue.zig # Job submission
|
||||
│ └── queue/ # Queue submodules
|
||||
│ ├── parse.zig # Job template parsing
|
||||
│ ├── validate.zig # Validation logic
|
||||
│ └── submit.zig # Job submission
|
||||
└── utils/ # Utilities (21 files)
|
||||
```
|
||||
|
||||
### Mode Dispatch Pattern
|
||||
|
||||
Commands auto-detect local vs server mode using `core.context.Context`:
|
||||
|
||||
```zig
|
||||
var ctx = core.context.Context.init(allocator, cfg, flags.json);
|
||||
if (ctx.isLocal()) {
|
||||
return try local.experiment.list(ctx.allocator, ctx.json_output);
|
||||
} else {
|
||||
return try server.experiment.list(ctx.allocator, ctx.json_output);
|
||||
}
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
|
@ -8,65 +45,139 @@ Fast CLI tool for managing ML experiments.
|
|||
# 1. Build
|
||||
zig build
|
||||
|
||||
# 2. Setup configuration
|
||||
# 2. Initialize local tracking (creates fetch_ml.db)
|
||||
./zig-out/bin/ml init
|
||||
|
||||
# 3. Run experiment
|
||||
./zig-out/bin/ml sync ./my-experiment --queue
|
||||
# 3. Create experiment and run locally
|
||||
./zig-out/bin/ml experiment create --name "baseline"
|
||||
./zig-out/bin/ml run start --experiment <id> --name "run-1"
|
||||
./zig-out/bin/ml experiment log --run <id> --name loss --value 0.5
|
||||
./zig-out/bin/ml run finish --run <id>
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
- `ml init` - Setup configuration
|
||||
- `ml sync <path>` - Sync project to server
|
||||
- `ml queue <job1> [job2 ...] [--commit <id>] [--priority N] [--note <text>]` - Queue one or more jobs
|
||||
- `ml status` - Check system/queue status for your API key
|
||||
- `ml validate <commit_id> [--json] [--task <task_id>]` - Validate provenance + integrity for a commit or task (includes `run_manifest.json` consistency checks when validating by task)
|
||||
- `ml info <path|id> [--json] [--base <path>]` - Show run info from `run_manifest.json` (by path or by scanning `finished/failed/running/pending`)
|
||||
- `ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]` - Append a human annotation to `run_manifest.json`
|
||||
- `ml narrative set <path|run_id|task_id> [--hypothesis <text>] [--context <text>] [--intent <text>] [--expected-outcome <text>] [--parent-run <id>] [--experiment-group <text>] [--tags <csv>] [--base <path>] [--json]` - Patch the `narrative` field in `run_manifest.json`
|
||||
- `ml monitor` - Launch monitoring interface (TUI)
|
||||
- `ml cancel <job>` - Cancel a running/queued job you own
|
||||
- `ml prune --keep N` - Keep N recent experiments
|
||||
- `ml watch <path>` - Auto-sync directory
|
||||
- `ml experiment log|show|list|delete` - Manage experiments and metrics
|
||||
### Local Mode Commands (SQLite)
|
||||
|
||||
- `ml init` - Initialize local experiment tracking database
|
||||
- `ml experiment create --name <name>` - Create experiment locally
|
||||
- `ml experiment list` - List experiments from SQLite
|
||||
- `ml experiment log --run <id> --name <key> --value <val>` - Log metrics
|
||||
- `ml run start --experiment <id> [--name <name>]` - Start a run
|
||||
- `ml run finish --run <id>` - Mark run as finished
|
||||
- `ml run fail --run <id>` - Mark run as failed
|
||||
- `ml run list` - List all runs
|
||||
|
||||
### Server Mode Commands (WebSocket)
|
||||
|
||||
- `ml sync <path>` - Sync project to server
|
||||
- `ml queue <job1> [job2 ...] [--commit <id>] [--priority N] [--note <text>]` - Queue jobs
|
||||
- `ml status` - Check system/queue status
|
||||
- `ml validate <commit_id> [--json] [--task <task_id>]` - Validate provenance
|
||||
- `ml cancel <job>` - Cancel a running/queued job
|
||||
|
||||
### Shared Commands (Auto-detect Mode)
|
||||
|
||||
- `ml experiment log|show|list|delete` - Works in both local and server mode
|
||||
- `ml monitor` - Launch TUI (local SQLite or remote SSH)
|
||||
|
||||
Notes:
|
||||
|
||||
- `--json` mode is designed to be pipe-friendly: machine-readable JSON is emitted to stdout, while user-facing messages/errors go to stderr.
|
||||
- When running `ml validate --task <task_id>`, the server will try to locate the job's `run_manifest.json` under the configured base path (pending/running/finished/failed) and cross-check key fields (task id, commit id, deps, snapshot).
|
||||
- For tasks in `running`, `completed`, or `failed` state, a missing `run_manifest.json` is treated as a validation failure. For `queued` tasks, it is treated as a warning (the job may not have started yet).
|
||||
- Commands auto-detect mode from config (`sqlite://` vs `wss://`)
|
||||
- `--json` mode is designed to be pipe-friendly
|
||||
|
||||
### Experiment workflow (minimal)
|
||||
## Core Modules
|
||||
|
||||
- `ml sync ./my-experiment --queue`
|
||||
Syncs files, computes a unique commit ID for the directory, and queues a job.
|
||||
### `core.context`
|
||||
|
||||
- `ml queue my-job`
|
||||
Queues a job named `my-job`. If `--commit` is omitted, the CLI generates a random commit ID
|
||||
and records `(job_name, commit_id)` in `~/.ml/history.log` so you don't have to remember hashes.
|
||||
Provides unified execution context for all commands:
|
||||
|
||||
- `ml queue my-job --note "baseline run; lr=1e-3"`
|
||||
Adds a human-readable note to the run; it will be persisted into the run's `run_manifest.json` (under `metadata.note`).
|
||||
- **Mode detection**: Automatically detects local (SQLite) vs server (WebSocket) mode
|
||||
- **Output handling**: JSON vs text output based on `--json` flag
|
||||
- **Dispatch helpers**: `ctx.dispatch(local_fn, server_fn, args)` for mode-specific implementations
|
||||
|
||||
- `ml experiment list`
|
||||
Shows recent experiments from history with alias (job name) and commit ID.
|
||||
```zig
|
||||
const core = @import("../core.zig");
|
||||
|
||||
- `ml experiment delete <alias|commit>`
|
||||
Cancels a running/queued experiment by job name, full commit ID, or short commit prefix.
|
||||
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
const cfg = try config.Config.load(allocator);
|
||||
var ctx = core.context.Context.init(allocator, cfg, flags.json);
|
||||
defer ctx.deinit();
|
||||
|
||||
// Dispatch to local or server implementation
|
||||
if (ctx.isLocal()) {
|
||||
return try local.experiment.list(ctx.allocator, ctx.json_output);
|
||||
} else {
|
||||
return try server.experiment.list(ctx.allocator, ctx.json_output);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `core.output`
|
||||
|
||||
Unified output helpers that respect `--json` flag:
|
||||
|
||||
```zig
|
||||
core.output.errorMsg("command", "Error message"); // JSON: {"success":false,...}
|
||||
core.output.success("command"); // JSON: {"success":true,...}
|
||||
core.output.successString("cmd", "key", "value"); // JSON with data
|
||||
core.output.info("Text output", .{}); // Text mode only
|
||||
core.output.usage("cmd", "usage string"); // Help text
|
||||
```
|
||||
|
||||
### `core.flags`
|
||||
|
||||
Common flag parsing utilities:
|
||||
|
||||
```zig
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var remaining = try core.flags.parseCommon(allocator, args, &flags);
|
||||
|
||||
// Check for subcommands
|
||||
if (core.flags.matchSubcommand(remaining.items, "list")) |sub_args| {
|
||||
return try executeList(ctx, sub_args);
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Create `~/.ml/config.toml`:
|
||||
### Local Mode (SQLite)
|
||||
|
||||
```toml
|
||||
# .fetchml/config.toml or ~/.ml/config.toml
|
||||
tracking_uri = "sqlite://./fetch_ml.db"
|
||||
artifact_path = "./experiments/"
|
||||
sync_uri = "" # Optional: server to sync with
|
||||
```
|
||||
|
||||
### Server Mode (WebSocket)
|
||||
|
||||
```toml
|
||||
# ~/.ml/config.toml
|
||||
worker_host = "worker.local"
|
||||
worker_user = "mluser"
|
||||
worker_user = "mluser"
|
||||
worker_base = "/data/ml-experiments"
|
||||
worker_port = 22
|
||||
api_key = "your-api-key"
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Development
|
||||
|
||||
```bash
|
||||
cd cli
|
||||
zig build
|
||||
```
|
||||
|
||||
### Production (requires SQLite in assets/)
|
||||
|
||||
```bash
|
||||
cd cli
|
||||
make build-sqlite # Fetch SQLite amalgamation
|
||||
zig build prod # Build with embedded SQLite
|
||||
```
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
|
|
@ -77,7 +188,47 @@ make install
|
|||
cp zig-out/bin/ml /usr/local/bin/
|
||||
```
|
||||
|
||||
## Local/Server Module Pattern
|
||||
|
||||
Commands that work in both modes follow this structure:
|
||||
|
||||
```
|
||||
src/
|
||||
├── local.zig # Module index
|
||||
├── local/
|
||||
│ └── experiment_ops.zig # Local implementations
|
||||
├── server.zig # Module index
|
||||
└── server/
|
||||
└── experiment_api.zig # Server implementations
|
||||
```
|
||||
|
||||
### Adding a New Command
|
||||
|
||||
1. Create local implementation in `src/local/<name>_ops.zig`
|
||||
2. Create server implementation in `src/server/<name>_api.zig`
|
||||
3. Export from `src/local.zig` and `src/server.zig`
|
||||
4. Create thin router in `src/commands/<name>.zig` using `ctx.dispatch()`
|
||||
|
||||
## Maintainability Cleanup (2026-02)
|
||||
|
||||
Recent refactoring improved code organization:
|
||||
|
||||
| Metric | Before | After |
|
||||
|--------|--------|-------|
|
||||
| experiment.zig | 836 lines | 348 lines (58% reduction) |
|
||||
| queue.zig | 1203 lines | Modular structure |
|
||||
| Duplicate printUsage | 24 functions | 1 shared helper |
|
||||
| Mode dispatch logic | Inlined everywhere | `core.context.Context` |
|
||||
|
||||
### Key Improvements
|
||||
|
||||
1. **Core Modules**: Unified `core.output`, `core.flags`, `core.context` eliminate duplication
|
||||
2. **Mode Abstraction**: Local/server operations separated into dedicated modules
|
||||
3. **Queue Decomposition**: `queue/` submodules for parsing, validation, submission
|
||||
4. **Bug Fixes**: Resolved 15+ compilation errors in `narrative.zig`, `outcome.zig`, `annotate.zig`, etc.
|
||||
|
||||
## Need Help?
|
||||
|
||||
- `ml --help` - Show command help
|
||||
- `ml <command> --help` - Show command-specific help
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ pub fn build(b: *std.Build) void {
|
|||
const test_filter = b.option([]const u8, "test-filter", "Filter unit tests by name");
|
||||
_ = test_filter;
|
||||
|
||||
// Optimized release mode for size
|
||||
const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseSmall });
|
||||
// Standard optimize option - let user choose, default to ReleaseSmall for production
|
||||
const optimize = b.standardOptimizeOption(.{});
|
||||
|
||||
const options = b.addOptions();
|
||||
|
||||
|
|
@ -28,8 +28,8 @@ pub fn build(b: *std.Build) void {
|
|||
else => "unknown",
|
||||
};
|
||||
|
||||
const candidate_specific = b.fmt("src/assets/rsync_release_{s}_{s}.bin", .{ os_str, arch_str });
|
||||
const candidate_default = "src/assets/rsync_release.bin";
|
||||
const candidate_specific = b.fmt("src/assets/rsync/rsync_release_{s}_{s}.bin", .{ os_str, arch_str });
|
||||
const candidate_default = "src/assets/rsync/rsync_release.bin";
|
||||
|
||||
var selected_candidate: []const u8 = "";
|
||||
var has_rsync_release = false;
|
||||
|
|
@ -55,14 +55,36 @@ pub fn build(b: *std.Build) void {
|
|||
// rsync_embedded_binary.zig calls @embedFile() from cli/src/utils, so the embed path
|
||||
// must be relative to that directory.
|
||||
const selected_embed_path = if (has_rsync_release)
|
||||
b.fmt("../assets/{s}", .{std.fs.path.basename(selected_candidate)})
|
||||
b.fmt("../assets/rsync/{s}", .{std.fs.path.basename(selected_candidate)})
|
||||
else
|
||||
"";
|
||||
|
||||
options.addOption(bool, "has_rsync_release", has_rsync_release);
|
||||
options.addOption([]const u8, "rsync_release_path", selected_embed_path);
|
||||
|
||||
// CLI executable
|
||||
// Check for SQLite assets (platform-specific only, no generic fallback)
|
||||
const sqlite_dir = b.fmt("src/assets/sqlite_{s}_{s}", .{ os_str, arch_str });
|
||||
|
||||
var has_sqlite_release = false;
|
||||
var sqlite_release_path: []const u8 = "";
|
||||
|
||||
// Only check platform-specific directory
|
||||
if (std.fs.cwd().access(sqlite_dir, .{})) |_| {
|
||||
has_sqlite_release = true;
|
||||
sqlite_release_path = sqlite_dir;
|
||||
} else |_| {}
|
||||
|
||||
if (optimize == .ReleaseSmall and !has_sqlite_release) {
|
||||
std.debug.panic(
|
||||
"ReleaseSmall build requires SQLite amalgamation (detected optimize={s}). Run: make build-sqlite",
|
||||
.{@tagName(optimize)},
|
||||
);
|
||||
}
|
||||
|
||||
options.addOption(bool, "has_sqlite_release", has_sqlite_release);
|
||||
options.addOption([]const u8, "sqlite_release_path", sqlite_release_path);
|
||||
|
||||
// CLI executable - declared BEFORE SQLite setup so exe can be referenced
|
||||
const exe = b.addExecutable(.{
|
||||
.name = "ml",
|
||||
.root_module = b.createModule(.{
|
||||
|
|
@ -73,8 +95,43 @@ pub fn build(b: *std.Build) void {
|
|||
});
|
||||
|
||||
exe.root_module.strip = true;
|
||||
|
||||
exe.root_module.addOptions("build_options", options);
|
||||
// LTO disabled: requires LLD linker which may not be available
|
||||
// exe.want_lto = true;
|
||||
|
||||
// Link native dataset_hash library
|
||||
exe.linkLibC();
|
||||
exe.addLibraryPath(b.path("../native/build"));
|
||||
exe.linkSystemLibrary("dataset_hash");
|
||||
exe.addIncludePath(b.path("../native/dataset_hash"));
|
||||
|
||||
// SQLite setup: embedded for ReleaseSmall only, system lib for dev
|
||||
const use_embedded_sqlite = has_sqlite_release and (optimize == .ReleaseSmall);
|
||||
if (use_embedded_sqlite) {
|
||||
// Release: compile SQLite from downloaded amalgamation
|
||||
const sqlite_c_path = b.fmt("{s}/sqlite3.c", .{sqlite_release_path});
|
||||
exe.addCSourceFile(.{ .file = b.path(sqlite_c_path), .flags = &.{
|
||||
"-DSQLITE_ENABLE_FTS5",
|
||||
"-DSQLITE_ENABLE_JSON1",
|
||||
"-DSQLITE_THREADSAFE=1",
|
||||
"-DSQLITE_USE_URI",
|
||||
} });
|
||||
exe.addIncludePath(b.path(sqlite_release_path));
|
||||
|
||||
// Compile SQLite constants wrapper (needed for SQLITE_TRANSIENT workaround)
|
||||
exe.addCSourceFile(.{ .file = b.path("src/assets/sqlite/sqlite_constants.c"), .flags = &.{ "-Wall", "-Wextra" } });
|
||||
} else {
|
||||
// Dev: link against system SQLite
|
||||
exe.linkSystemLibrary("sqlite3");
|
||||
|
||||
// Add system include paths for sqlite3.h
|
||||
exe.addIncludePath(.{ .cwd_relative = "/usr/include" });
|
||||
exe.addIncludePath(.{ .cwd_relative = "/usr/local/include" });
|
||||
exe.addIncludePath(.{ .cwd_relative = "/opt/homebrew/include" });
|
||||
|
||||
// Compile SQLite constants wrapper with system headers
|
||||
exe.addCSourceFile(.{ .file = b.path("src/assets/sqlite/sqlite_constants.c"), .flags = &.{ "-Wall", "-Wextra" } });
|
||||
}
|
||||
|
||||
// Install the executable to zig-out/bin
|
||||
b.installArtifact(exe);
|
||||
|
|
@ -94,6 +151,13 @@ pub fn build(b: *std.Build) void {
|
|||
// Standard Zig test discovery - find all test files automatically
|
||||
const test_step = b.step("test", "Run unit tests");
|
||||
|
||||
// Safety check for release builds
|
||||
const safety_check_step = b.step("safety-check", "Verify ReleaseSafe mode is used for production");
|
||||
if (optimize != .ReleaseSafe and optimize != .Debug) {
|
||||
const warn_no_safe = b.addSystemCommand(&.{ "echo", "WARNING: Building without ReleaseSafe mode. Production builds should use -Doptimize=ReleaseSafe" });
|
||||
safety_check_step.dependOn(&warn_no_safe.step);
|
||||
}
|
||||
|
||||
// Test main executable
|
||||
const main_tests = b.addTest(.{
|
||||
.root_module = b.createModule(.{
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ if [[ "${os}" != "linux" ]]; then
|
|||
fi
|
||||
|
||||
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
out="${repo_root}/src/assets/rsync_release_linux_${arch}.bin"
|
||||
mkdir -p "${repo_root}/src/assets/rsync"
|
||||
out="${repo_root}/src/assets/rsync/rsync_release_${os}_${arch}.bin"
|
||||
|
||||
tmp="$(mktemp -d)"
|
||||
cleanup() { rm -rf "${tmp}"; }
|
||||
|
|
|
|||
50
cli/scripts/build_sqlite.sh
Normal file
50
cli/scripts/build_sqlite.sh
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
#!/bin/bash
|
||||
# Build/fetch SQLite amalgamation for embedding
|
||||
# Mirrors the rsync pattern: assets/sqlite_release_<os>_<arch>/
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SQLITE_VERSION="${SQLITE_VERSION:-3480000}" # 3.48.0
|
||||
SQLITE_YEAR="${SQLITE_YEAR:-2025}"
|
||||
SQLITE_SRC_BASE="${SQLITE_SRC_BASE:-https://www.sqlite.org/${SQLITE_YEAR}}"
|
||||
|
||||
os="$(uname -s | tr '[:upper:]' '[:lower:]')"
|
||||
arch="$(uname -m)"
|
||||
if [[ "${arch}" == "aarch64" || "${arch}" == "arm64" ]]; then arch="arm64"; fi
|
||||
if [[ "${arch}" == "x86_64" ]]; then arch="x86_64"; fi
|
||||
|
||||
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
out_dir="${repo_root}/src/assets/sqlite_${os}_${arch}"
|
||||
|
||||
echo "Fetching SQLite ${SQLITE_VERSION} for ${os}/${arch}..."
|
||||
|
||||
# Create platform-specific output directory
|
||||
mkdir -p "${out_dir}"
|
||||
|
||||
# Download if not present
|
||||
if [[ ! -f "${out_dir}/sqlite3.c" ]]; then
|
||||
echo "Fetching SQLite amalgamation..."
|
||||
tmp="$(mktemp -d)"
|
||||
cleanup() { rm -rf "${tmp}"; }
|
||||
trap cleanup EXIT
|
||||
|
||||
url="${SQLITE_SRC_BASE}/sqlite-amalgamation-${SQLITE_VERSION}.zip"
|
||||
echo "Fetching ${url}"
|
||||
curl -fsSL "${url}" -o "${tmp}/sqlite.zip"
|
||||
|
||||
unzip -q "${tmp}/sqlite.zip" -d "${tmp}"
|
||||
mv "${tmp}/sqlite-amalgamation-${SQLITE_VERSION}"/* "${out_dir}/"
|
||||
|
||||
echo "✓ SQLite fetched to ${out_dir}"
|
||||
else
|
||||
echo "✓ SQLite already present at ${out_dir}"
|
||||
fi
|
||||
|
||||
# Verify
|
||||
if [[ -f "${out_dir}/sqlite3.c" && -f "${out_dir}/sqlite3.h" ]]; then
|
||||
echo "✓ SQLite ready:"
|
||||
ls -lh "${out_dir}/sqlite3.c" "${out_dir}/sqlite3.h"
|
||||
else
|
||||
echo "Error: SQLite files not found in ${out_dir}"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -11,11 +11,11 @@ _ml_completions()
|
|||
cur="${COMP_WORDS[COMP_CWORD]}"
|
||||
prev="${COMP_WORDS[COMP_CWORD-1]}"
|
||||
|
||||
# Global options
|
||||
global_opts="--help --verbose --quiet --monitor"
|
||||
|
||||
# Top-level subcommands
|
||||
cmds="init sync queue requeue status monitor cancel prune watch dataset experiment"
|
||||
cmds="init sync queue requeue status monitor cancel prune watch dataset experiment narrative outcome info logs annotate validate compare find export"
|
||||
|
||||
# Global options
|
||||
global_opts="--help --verbose --quiet --monitor --json"
|
||||
|
||||
# If completing the subcommand itself
|
||||
if [[ ${COMP_CWORD} -eq 1 ]]; then
|
||||
|
|
@ -41,7 +41,7 @@ _ml_completions()
|
|||
COMPREPLY=( $(compgen -d -- "${cur}") )
|
||||
;;
|
||||
queue)
|
||||
queue_opts="--commit --priority --cpu --memory --gpu --gpu-memory --snapshot-id --snapshot-sha256 --args -- ${global_opts}"
|
||||
queue_opts="--commit --priority --cpu --memory --gpu --gpu-memory --snapshot-id --snapshot-sha256 --args --note --hypothesis --context --intent --expected-outcome --experiment-group --tags --dry-run --validate --explain --force --mlflow --mlflow-uri --tensorboard --wandb-key --wandb-project --wandb-entity -- ${global_opts}"
|
||||
case "${prev}" in
|
||||
--priority)
|
||||
COMPREPLY=( $(compgen -W "0 1 2 3 4 5 6 7 8 9 10" -- "${cur}") )
|
||||
|
|
@ -52,7 +52,7 @@ _ml_completions()
|
|||
--gpu-memory)
|
||||
COMPREPLY=( $(compgen -W "4 8 16 24 32 48" -- "${cur}") )
|
||||
;;
|
||||
--commit|--snapshot-id|--snapshot-sha256|--args)
|
||||
--commit|--snapshot-id|--snapshot-sha256|--args|--note|--hypothesis|--context|--intent|--expected-outcome|--experiment-group|--tags|--mlflow-uri|--wandb-key|--wandb-project|--wandb-entity)
|
||||
# Free-form; no special completion
|
||||
;;
|
||||
*)
|
||||
|
|
@ -60,7 +60,7 @@ _ml_completions()
|
|||
COMPREPLY=( $(compgen -W "${queue_opts}" -- "${cur}") )
|
||||
else
|
||||
# Suggest common job names (static for now)
|
||||
COMPREPLY=( $(compgen -W "train evaluate deploy" -- "${cur}") )
|
||||
COMPREPLY=( $(compgen -W "train evaluate deploy test baseline" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
|
|
@ -114,16 +114,111 @@ _ml_completions()
|
|||
COMPREPLY=( $(compgen -d -- "${cur}") )
|
||||
;;
|
||||
dataset)
|
||||
COMPREPLY=( $(compgen -W "list upload download delete info search" -- "${cur}") )
|
||||
COMPREPLY=( $(compgen -W "list register info search verify" -- "${cur}") )
|
||||
;;
|
||||
experiment)
|
||||
COMPREPLY=( $(compgen -W "log show" -- "${cur}") )
|
||||
COMPREPLY=( $(compgen -W "log show list" -- "${cur}") )
|
||||
;;
|
||||
*)
|
||||
# Fallback to global options
|
||||
COMPREPLY=( $(compgen -W "${global_opts}" -- "${cur}") )
|
||||
narrative)
|
||||
narrative_opts="set --hypothesis --context --intent --expected-outcome --parent-run --experiment-group --tags --base --json --help"
|
||||
case "${prev}" in
|
||||
set)
|
||||
COMPREPLY=( $(compgen -W "${narrative_opts}" -- "${cur}") )
|
||||
;;
|
||||
--hypothesis|--context|--intent|--expected-outcome|--parent-run|--experiment-group|--tags|--base)
|
||||
# Free-form completion
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${narrative_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
outcome)
|
||||
outcome_opts="set --outcome --summary --learning --next-step --validation-status --surprise --base --json --help"
|
||||
case "${prev}" in
|
||||
set)
|
||||
COMPREPLY=( $(compgen -W "${outcome_opts}" -- "${cur}") )
|
||||
;;
|
||||
--outcome)
|
||||
COMPREPLY=( $(compgen -W "validates refutes inconclusive partial" -- "${cur}") )
|
||||
;;
|
||||
--validation-status)
|
||||
COMPREPLY=( $(compgen -W "validates refutes inconclusive" -- "${cur}") )
|
||||
;;
|
||||
--summary|--learning|--next-step|--surprise|--base)
|
||||
# Free-form completion
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${outcome_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
info|logs|annotate|validate)
|
||||
# These commands take a path/id and various options
|
||||
info_opts="--base --json --help"
|
||||
case "${prev}" in
|
||||
--base)
|
||||
COMPREPLY=( $(compgen -d -- "${cur}") )
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${info_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
compare)
|
||||
compare_opts="--json --csv --all --fields --help"
|
||||
case "${prev}" in
|
||||
--fields)
|
||||
COMPREPLY=( $(compgen -W "narrative,metrics,metadata,outcome" -- "${cur}") )
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${compare_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
find)
|
||||
find_opts="--json --csv --limit --tag --outcome --dataset --experiment-group --author --after --before --help"
|
||||
case "${prev}" in
|
||||
--limit)
|
||||
COMPREPLY=( $(compgen -W "10 20 50 100" -- "${cur}") )
|
||||
;;
|
||||
--outcome)
|
||||
COMPREPLY=( $(compgen -W "validates refutes inconclusive partial" -- "${cur}") )
|
||||
;;
|
||||
--tag|--dataset|--experiment-group|--author|--after|--before)
|
||||
# Free-form completion
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${find_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
export)
|
||||
export_opts="--bundle --anonymize --anonymize-level --base --json --help"
|
||||
case "${prev}" in
|
||||
--anonymize-level)
|
||||
COMPREPLY=( $(compgen -W "metadata-only full" -- "${cur}") )
|
||||
;;
|
||||
--bundle|--base)
|
||||
COMPREPLY=( $(compgen -f -- "${cur}") )
|
||||
;;
|
||||
*)
|
||||
if [[ "${cur}" == --* ]]; then
|
||||
COMPREPLY=( $(compgen -W "${export_opts}" -- "${cur}") )
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
esac
|
||||
|
||||
return 0
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,14 @@ _ml() {
|
|||
'prune:Prune old experiments'
|
||||
'watch:Watch directory for auto-sync'
|
||||
'dataset:Manage datasets'
|
||||
'experiment:Manage experiments'
|
||||
'find:Search experiments by tags/outcome/dataset'
|
||||
'export:Export experiment for sharing'
|
||||
'compare:Compare two runs narrative fields'
|
||||
'outcome:Set post-run outcome'
|
||||
'info:Show run info'
|
||||
'logs:Fetch job logs'
|
||||
'annotate:Add annotation to run'
|
||||
'validate:Validate provenance'
|
||||
)
|
||||
|
||||
local -a global_opts
|
||||
|
|
@ -26,6 +33,7 @@ _ml() {
|
|||
'--verbose:Enable verbose output'
|
||||
'--quiet:Suppress non-error output'
|
||||
'--monitor:Monitor progress of long-running operations'
|
||||
'--json:Output structured JSON'
|
||||
)
|
||||
|
||||
local curcontext="$curcontext" state line
|
||||
|
|
@ -53,7 +61,7 @@ _ml() {
|
|||
'--help[Show queue help]' \
|
||||
'--verbose[Enable verbose output]' \
|
||||
'--quiet[Suppress non-error output]' \
|
||||
'--monitor[Monitor progress]' \
|
||||
'--json[Output JSON]' \
|
||||
'--commit[Commit id (40-hex) or unique prefix (>=7)]:commit id:' \
|
||||
'--priority[Priority (0-255)]:priority:' \
|
||||
'--cpu[CPU cores]:cpu:' \
|
||||
|
|
@ -63,6 +71,23 @@ _ml() {
|
|||
'--snapshot-id[Snapshot id]:snapshot id:' \
|
||||
'--snapshot-sha256[Snapshot sha256]:snapshot sha256:' \
|
||||
'--args[Runner args string]:args:' \
|
||||
'--note[Human notes]:note:' \
|
||||
'--hypothesis[Research hypothesis]:hypothesis:' \
|
||||
'--context[Background context]:context:' \
|
||||
'--intent[What you are trying to accomplish]:intent:' \
|
||||
'--expected-outcome[What you expect to happen]:expected outcome:' \
|
||||
'--experiment-group[Group related experiments]:experiment group:' \
|
||||
'--tags[Comma-separated tags]:tags:' \
|
||||
'--dry-run[Show what would be queued]' \
|
||||
'--validate[Validate without queuing]' \
|
||||
'--explain[Explain what will happen]' \
|
||||
'--force[Queue even if duplicate exists]' \
|
||||
'--mlflow[Enable MLflow]' \
|
||||
'--mlflow-uri[MLflow tracking URI]:uri:' \
|
||||
'--tensorboard[Enable TensorBoard]' \
|
||||
'--wandb-key[Wandb API key]:key:' \
|
||||
'--wandb-project[Wandb project]:project:' \
|
||||
'--wandb-entity[Wandb entity]:entity:' \
|
||||
'1:job name:' \
|
||||
'*:args separator:(--)'
|
||||
;;
|
||||
|
|
@ -124,16 +149,85 @@ _ml() {
|
|||
dataset)
|
||||
_values 'dataset subcommand' \
|
||||
'list[list datasets]' \
|
||||
'upload[upload dataset]' \
|
||||
'download[download dataset]' \
|
||||
'delete[delete dataset]' \
|
||||
'register[register dataset]' \
|
||||
'info[dataset info]' \
|
||||
'search[search datasets]'
|
||||
'search[search datasets]' \
|
||||
'verify[verify dataset]'
|
||||
;;
|
||||
experiment)
|
||||
_values 'experiment subcommand' \
|
||||
'log[log metrics]' \
|
||||
'show[show experiment]'
|
||||
'show[show experiment]' \
|
||||
'list[list experiments]'
|
||||
;;
|
||||
narrative)
|
||||
_arguments -C \
|
||||
'1:subcommand:(set)' \
|
||||
'--hypothesis[Research hypothesis]:hypothesis:' \
|
||||
'--context[Background context]:context:' \
|
||||
'--intent[What you are trying to accomplish]:intent:' \
|
||||
'--expected-outcome[What you expect to happen]:expected outcome:' \
|
||||
'--parent-run[Parent run ID]:parent run:' \
|
||||
'--experiment-group[Group related experiments]:experiment group:' \
|
||||
'--tags[Comma-separated tags]:tags:' \
|
||||
'--base[Base path]:base:_directories' \
|
||||
'--json[Output JSON]' \
|
||||
'--help[Show help]'
|
||||
;;
|
||||
outcome)
|
||||
_arguments -C \
|
||||
'1:subcommand:(set)' \
|
||||
'--outcome[Outcome status]:outcome:(validates refutes inconclusive partial)' \
|
||||
'--summary[Summary of results]:summary:' \
|
||||
'--learning[A learning from this run]:learning:' \
|
||||
'--next-step[Suggested next step]:next step:' \
|
||||
'--validation-status[Did results validate hypothesis]:(validates refutes inconclusive)' \
|
||||
'--surprise[Unexpected finding]:surprise:' \
|
||||
'--base[Base path]:base:_directories' \
|
||||
'--json[Output JSON]' \
|
||||
'--help[Show help]'
|
||||
;;
|
||||
info|logs|annotate|validate)
|
||||
_arguments -C \
|
||||
'--base[Base path]:base:_directories' \
|
||||
'--json[Output JSON]' \
|
||||
'--help[Show help]' \
|
||||
'1:run id or path:'
|
||||
;;
|
||||
compare)
|
||||
_arguments -C \
|
||||
'--help[Show compare help]' \
|
||||
'--json[Output JSON]' \
|
||||
'--csv[Output CSV for spreadsheets]' \
|
||||
'--all[Show all fields]' \
|
||||
'--fields[Specify fields to compare]:fields:(narrative metrics metadata outcome)' \
|
||||
'1:run a:' \
|
||||
'2:run b:'
|
||||
;;
|
||||
find)
|
||||
_arguments -C \
|
||||
'--help[Show find help]' \
|
||||
'--json[Output JSON]' \
|
||||
'--csv[Output CSV for spreadsheets]' \
|
||||
'--limit[Max results]:limit:(10 20 50 100)' \
|
||||
'--tag[Filter by tag]:tag:' \
|
||||
'--outcome[Filter by outcome]:outcome:(validates refutes inconclusive partial)' \
|
||||
'--dataset[Filter by dataset]:dataset:' \
|
||||
'--experiment-group[Filter by group]:group:' \
|
||||
'--author[Filter by author]:author:' \
|
||||
'--after[After date]:date:' \
|
||||
'--before[Before date]:date:' \
|
||||
'::query:'
|
||||
;;
|
||||
export)
|
||||
_arguments -C \
|
||||
'--help[Show export help]' \
|
||||
'--json[Output JSON]' \
|
||||
'--bundle[Create bundle at path]:path:_files' \
|
||||
'--anonymize[Enable anonymization]' \
|
||||
'--anonymize-level[Anonymization level]:level:(metadata-only full)' \
|
||||
'--base[Base path]:base:_directories' \
|
||||
'1:run id or path:'
|
||||
;;
|
||||
*)
|
||||
_arguments -C "${global_opts[@]}"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
This directory contains rsync binaries for the ML CLI:
|
||||
|
||||
- `rsync_placeholder.bin` - Wrapper script for dev builds (calls system rsync)
|
||||
- `rsync_release_<os>_<arch>.bin` - Static rsync binary for release builds (not in repo)
|
||||
- `rsync/rsync_release.bin` - Static rsync binary for release builds (symlink to placeholder)
|
||||
- `rsync/rsync_release_<os>_<arch>.bin` - Downloaded static binary (not in repo)
|
||||
|
||||
## Build Modes
|
||||
|
||||
|
|
@ -16,8 +17,8 @@ This directory contains rsync binaries for the ML CLI:
|
|||
- Requires rsync installed on the system
|
||||
|
||||
### Release Builds (ReleaseSmall, ReleaseFast)
|
||||
- Uses `rsync_release_<os>_<arch>.bin` (static binary)
|
||||
- Fully self-contained, no dependencies
|
||||
- Uses `rsync/rsync_release_<os>_<arch>.bin` (downloaded static binary)
|
||||
- Falls back to `rsync/rsync_release.bin` (symlink to placeholder) if platform-specific not found
|
||||
- Results in ~450-650KB CLI binary
|
||||
- Works on any system without rsync installed
|
||||
|
||||
|
|
@ -44,30 +45,29 @@ cd rsync-3.3.0
|
|||
make
|
||||
|
||||
# Copy to assets (example)
|
||||
cp rsync ../fetch_ml/cli/src/assets/rsync_release_linux_x86_64.bin
|
||||
cp rsync ../fetch_ml/cli/src/assets/rsync/rsync_release_linux_x86_64.bin
|
||||
```
|
||||
|
||||
### Option 3: Use System Rsync (Temporary)
|
||||
|
||||
For testing release builds without a static binary:
|
||||
```bash
|
||||
cd cli/src/assets
|
||||
cp rsync_placeholder.bin rsync_release_linux_x86_64.bin
|
||||
cd cli/src/assets/rsync
|
||||
ln -sf rsync_placeholder.bin rsync_release.bin
|
||||
```
|
||||
|
||||
This will still use the wrapper, but allows builds to complete.
|
||||
|
||||
## Verification
|
||||
|
||||
After placing the appropriate `rsync_release_<os>_<arch>.bin`:
|
||||
After placing the appropriate `rsync/rsync_release_<os>_<arch>.bin`:
|
||||
|
||||
```bash
|
||||
|
||||
# Verify it's executable (example)
|
||||
file cli/src/assets/rsync_release_linux_x86_64.bin
|
||||
file cli/src/assets/rsync/rsync_release_linux_x86_64.bin
|
||||
|
||||
# Test it (example)
|
||||
./cli/src/assets/rsync_release_linux_x86_64.bin --version
|
||||
./cli/src/assets/rsync/rsync_release_linux_x86_64.bin --version
|
||||
|
||||
# Build release
|
||||
cd cli
|
||||
|
|
@ -79,7 +79,76 @@ ls -lh zig-out/prod/ml
|
|||
|
||||
## Notes
|
||||
|
||||
- `rsync_release.bin` is not tracked in git (add to .gitignore if needed)
|
||||
- `rsync/rsync_release_<os>_<arch>.bin` is not tracked in git
|
||||
- Different platforms need different static binaries
|
||||
- For cross-compilation, provide platform-specific binaries
|
||||
- The wrapper approach for dev builds is intentional for fast iteration
|
||||
|
||||
---
|
||||
|
||||
# SQLite Amalgamation Setup for Local Mode
|
||||
|
||||
## Overview
|
||||
|
||||
This directory contains SQLite source for FetchML local mode:
|
||||
|
||||
- `sqlite_<os>_<arch>/` - SQLite amalgamation for release builds (fetched, not in repo)
|
||||
- `sqlite3.c` - Single-file SQLite implementation
|
||||
- `sqlite3.h` - SQLite header file
|
||||
|
||||
## Build Modes
|
||||
|
||||
### Development/Debug Builds
|
||||
- Links against system SQLite library (`libsqlite3`)
|
||||
- Requires SQLite installed on system
|
||||
- Faster builds, smaller binary
|
||||
|
||||
### Release Builds (ReleaseSmall, ReleaseFast)
|
||||
- Compiles SQLite from downloaded amalgamation
|
||||
- Self-contained, no external dependencies
|
||||
- Works on any system without SQLite installed
|
||||
|
||||
## Preparing SQLite
|
||||
|
||||
### Option 1: Fetch from Official Source (recommended)
|
||||
|
||||
```bash
|
||||
cd cli
|
||||
make build-sqlite SQLITE_VERSION=3480000
|
||||
```
|
||||
|
||||
### Option 2: Download Yourself
|
||||
|
||||
```bash
|
||||
# Download official amalgamation
|
||||
SQLITE_VERSION=3480000
|
||||
SQLITE_YEAR=2025
|
||||
|
||||
cd cli
|
||||
make build-sqlite
|
||||
# Output: src/assets/sqlite_<os>_<arch>/
|
||||
```
|
||||
|
||||
## Verification
|
||||
|
||||
After fetching SQLite:
|
||||
|
||||
```bash
|
||||
# Verify files exist (example for darwin/arm64)
|
||||
ls -lh cli/src/assets/sqlite_darwin_arm64/sqlite3.c
|
||||
ls -lh cli/src/assets/sqlite_darwin_arm64/sqlite3.h
|
||||
|
||||
# Build CLI
|
||||
cd cli
|
||||
zig build prod
|
||||
|
||||
# Check binary works with local mode
|
||||
./zig-out/bin/ml init
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- `sqlite_<os>_<arch>/` directories are not tracked in git
|
||||
- Dev builds use system SQLite; release builds embed amalgamation
|
||||
- WAL mode is enabled for concurrent CLI writes and TUI reads
|
||||
- The amalgamation approach matches SQLite's recommended embedding pattern
|
||||
|
|
|
|||
2
cli/src/assets/rsync/rsync_placeholder.bin
Executable file
2
cli/src/assets/rsync/rsync_placeholder.bin
Executable file
|
|
@ -0,0 +1,2 @@
|
|||
#!/bin/sh
|
||||
exec /usr/bin/rsync "$@"
|
||||
1
cli/src/assets/rsync/rsync_release.bin
Symbolic link
1
cli/src/assets/rsync/rsync_release.bin
Symbolic link
|
|
@ -0,0 +1 @@
|
|||
rsync_placeholder.bin
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
#!/bin/bash
|
||||
# Rsync wrapper for development builds
|
||||
# This calls the system's rsync instead of embedding a full binary
|
||||
# Keeps the dev binary small (152KB) while still functional
|
||||
|
||||
# Find rsync on the system
|
||||
RSYNC_PATH=$(which rsync 2>/dev/null || echo "/usr/bin/rsync")
|
||||
|
||||
if [ ! -x "$RSYNC_PATH" ]; then
|
||||
echo "Error: rsync not found on system. Please install rsync or use a release build with embedded rsync." >&2
|
||||
exit 127
|
||||
fi
|
||||
|
||||
# Pass all arguments to system rsync
|
||||
exec "$RSYNC_PATH" "$@"
|
||||
9
cli/src/assets/sqlite/sqlite_constants.c
Normal file
9
cli/src/assets/sqlite/sqlite_constants.c
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
// sqlite_constants.c - Simple C wrapper to export SQLITE_TRANSIENT
|
||||
// This works around Zig 0.15's C translation issue with SQLITE_TRANSIENT
|
||||
#include <sqlite3.h>
|
||||
|
||||
// Export SQLITE_TRANSIENT as a function that returns the value
|
||||
// This avoids the comptime C translation issue
|
||||
const void* fetchml_sqlite_transient(void) {
|
||||
return SQLITE_TRANSIENT;
|
||||
}
|
||||
|
|
@ -1,16 +1,17 @@
|
|||
pub const annotate = @import("commands/annotate.zig");
|
||||
pub const cancel = @import("commands/cancel.zig");
|
||||
pub const compare = @import("commands/compare.zig");
|
||||
pub const dataset = @import("commands/dataset.zig");
|
||||
pub const experiment = @import("commands/experiment.zig");
|
||||
pub const export_cmd = @import("commands/export_cmd.zig");
|
||||
pub const find = @import("commands/find.zig");
|
||||
pub const info = @import("commands/info.zig");
|
||||
pub const init = @import("commands/init.zig");
|
||||
pub const jupyter = @import("commands/jupyter.zig");
|
||||
pub const logs = @import("commands/logs.zig");
|
||||
pub const monitor = @import("commands/monitor.zig");
|
||||
pub const narrative = @import("commands/narrative.zig");
|
||||
pub const logs = @import("commands/log.zig");
|
||||
pub const prune = @import("commands/prune.zig");
|
||||
pub const queue = @import("commands/queue.zig");
|
||||
pub const requeue = @import("commands/requeue.zig");
|
||||
pub const run = @import("commands/run.zig");
|
||||
pub const status = @import("commands/status.zig");
|
||||
pub const sync = @import("commands/sync.zig");
|
||||
pub const validate = @import("commands/validate.zig");
|
||||
|
|
|
|||
|
|
@ -1,159 +1,143 @@
|
|||
const std = @import("std");
|
||||
const config = @import("../config.zig");
|
||||
const db = @import("../db.zig");
|
||||
const core = @import("../core.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
/// Annotate command - unified metadata annotation
|
||||
/// Usage:
|
||||
/// ml annotate <run_id> --text "Try lr=3e-4 next"
|
||||
/// ml annotate <run_id> --hypothesis "LR scaling helps"
|
||||
/// ml annotate <run_id> --outcome validates --confidence 0.9
|
||||
/// ml annotate <run_id> --privacy private
|
||||
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var command_args = try core.flags.parseCommon(allocator, args, &flags);
|
||||
defer command_args.deinit(allocator);
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
if (flags.help) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
if (command_args.items.len < 1) {
|
||||
std.log.err("Usage: ml annotate <run_id> [options]", .{});
|
||||
return error.MissingArgument;
|
||||
}
|
||||
|
||||
const target = args[0];
|
||||
const run_id = command_args.items[0];
|
||||
|
||||
var author: []const u8 = "";
|
||||
var note: ?[]const u8 = null;
|
||||
var base_override: ?[]const u8 = null;
|
||||
var json_mode: bool = false;
|
||||
// Parse metadata options
|
||||
const text = core.flags.parseKVFlag(command_args.items, "text");
|
||||
const hypothesis = core.flags.parseKVFlag(command_args.items, "hypothesis");
|
||||
const outcome = core.flags.parseKVFlag(command_args.items, "outcome");
|
||||
const confidence = core.flags.parseKVFlag(command_args.items, "confidence");
|
||||
const privacy = core.flags.parseKVFlag(command_args.items, "privacy");
|
||||
const author = core.flags.parseKVFlag(command_args.items, "author");
|
||||
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const a = args[i];
|
||||
if (std.mem.eql(u8, a, "--author")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --author\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
author = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--note")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --note\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
note = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--base")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --base\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
base_override = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else if (std.mem.startsWith(u8, a, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
} else {
|
||||
colors.printError("Unexpected argument: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
// Check that at least one option is provided
|
||||
if (text == null and hypothesis == null and outcome == null and privacy == null) {
|
||||
std.log.err("No metadata provided. Use --text, --hypothesis, --outcome, or --privacy", .{});
|
||||
return error.MissingMetadata;
|
||||
}
|
||||
|
||||
if (note == null or std.mem.trim(u8, note.?, " \t\r\n").len == 0) {
|
||||
colors.printError("--note is required\n", .{});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = base_override orelse cfg.worker_base;
|
||||
// Get DB path
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
const job_name = try manifest.readJobNameFromManifest(allocator, manifest_path);
|
||||
defer allocator.free(job_name);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendAnnotateRun(job_name, author, note.?, api_key_hash);
|
||||
|
||||
if (json_mode) {
|
||||
const msg = try client.receiveMessage(allocator);
|
||||
defer allocator.free(msg);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
return error.InvalidPacket;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
const Result = struct {
|
||||
ok: bool,
|
||||
job_name: []const u8,
|
||||
message: []const u8,
|
||||
error_code: ?u8 = null,
|
||||
error_message: ?[]const u8 = null,
|
||||
details: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
var out = io.stdoutWriter();
|
||||
if (packet.packet_type == .error_packet) {
|
||||
const res = Result{
|
||||
.ok = false,
|
||||
.job_name = job_name,
|
||||
.message = "",
|
||||
.error_code = @intFromEnum(packet.error_code.?),
|
||||
.error_message = packet.error_message orelse "",
|
||||
.details = packet.error_details orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return error.CommandFailed;
|
||||
}
|
||||
|
||||
const res = Result{
|
||||
.ok = true,
|
||||
.job_name = job_name,
|
||||
.message = packet.success_message orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return;
|
||||
// Verify run exists
|
||||
const check_sql = "SELECT 1 FROM ml_runs WHERE run_id = ?;";
|
||||
const check_stmt = try database.prepare(check_sql);
|
||||
defer db.DB.finalize(check_stmt);
|
||||
try db.DB.bindText(check_stmt, 1, run_id);
|
||||
const has_row = try db.DB.step(check_stmt);
|
||||
if (!has_row) {
|
||||
std.log.err("Run not found: {s}", .{run_id});
|
||||
return error.RunNotFound;
|
||||
}
|
||||
|
||||
try client.receiveAndHandleResponse(allocator, "Annotate");
|
||||
// Add text note as a tag
|
||||
if (text) |t| {
|
||||
try addTag(allocator, &database, run_id, "note", t, author);
|
||||
}
|
||||
|
||||
colors.printSuccess("Annotation added\n", .{});
|
||||
colors.printInfo("Job: {s}\n", .{job_name});
|
||||
// Add hypothesis
|
||||
if (hypothesis) |h| {
|
||||
try addTag(allocator, &database, run_id, "hypothesis", h, author);
|
||||
}
|
||||
|
||||
// Add outcome
|
||||
if (outcome) |o| {
|
||||
try addTag(allocator, &database, run_id, "outcome", o, author);
|
||||
if (confidence) |c| {
|
||||
try addTag(allocator, &database, run_id, "confidence", c, author);
|
||||
}
|
||||
}
|
||||
|
||||
// Add privacy level
|
||||
if (privacy) |p| {
|
||||
try addTag(allocator, &database, run_id, "privacy", p, author);
|
||||
}
|
||||
|
||||
// Checkpoint WAL
|
||||
database.checkpointOnExit();
|
||||
|
||||
if (flags.json) {
|
||||
std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"action\":\"note_added\"}}\n", .{run_id});
|
||||
} else {
|
||||
colors.printSuccess("✓ Added note to run {s}\n", .{run_id[0..8]});
|
||||
}
|
||||
}
|
||||
|
||||
fn addTag(
|
||||
allocator: std.mem.Allocator,
|
||||
database: *db.DB,
|
||||
run_id: []const u8,
|
||||
key: []const u8,
|
||||
value: []const u8,
|
||||
author: ?[]const u8,
|
||||
) !void {
|
||||
const full_value = if (author) |a|
|
||||
try std.fmt.allocPrint(allocator, "{s} (by {s})", .{ value, a })
|
||||
else
|
||||
try allocator.dupe(u8, value);
|
||||
defer allocator.free(full_value);
|
||||
|
||||
const sql = "INSERT INTO ml_tags (run_id, key, value) VALUES (?, ?, ?);";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
|
||||
try db.DB.bindText(stmt, 1, run_id);
|
||||
try db.DB.bindText(stmt, 2, key);
|
||||
try db.DB.bindText(stmt, 3, full_value);
|
||||
_ = try db.DB.step(stmt);
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml annotate <path|run_id|task_id> --note <text> [--author <name>] [--base <path>] [--json]\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml annotate 8b3f... --note \"Try lr=3e-4 next\"\n", .{});
|
||||
colors.printInfo(" ml annotate ./finished/job-123 --note \"Baseline looks stable\" --author alice\n", .{});
|
||||
std.debug.print("Usage: ml annotate <run_id> [options]\n\n", .{});
|
||||
std.debug.print("Add metadata annotations to a run.\n\n", .{});
|
||||
std.debug.print("Options:\n", .{});
|
||||
std.debug.print(" --text <string> Free-form annotation\n", .{});
|
||||
std.debug.print(" --hypothesis <string> Research hypothesis\n", .{});
|
||||
std.debug.print(" --outcome <status> Outcome: validates/refutes/inconclusive\n", .{});
|
||||
std.debug.print(" --confidence <0-1> Confidence in outcome\n", .{});
|
||||
std.debug.print(" --privacy <level> Privacy: private/team/public\n", .{});
|
||||
std.debug.print(" --author <name> Author of the annotation\n", .{});
|
||||
std.debug.print(" --help, -h Show this help\n", .{});
|
||||
std.debug.print(" --json Output structured JSON\n\n", .{});
|
||||
std.debug.print("Examples:\n", .{});
|
||||
std.debug.print(" ml annotate abc123 --text \"Try lr=3e-4 next\"\n", .{});
|
||||
std.debug.print(" ml annotate abc123 --hypothesis \"LR scaling helps\"\n", .{});
|
||||
std.debug.print(" ml annotate abc123 --outcome validates --confidence 0.9\n", .{});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,126 +1,212 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const config = @import("../config.zig");
|
||||
const db = @import("../db.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const auth = @import("../utils/auth.zig");
|
||||
|
||||
pub const CancelOptions = struct {
|
||||
force: bool = false,
|
||||
json: bool = false,
|
||||
};
|
||||
const core = @import("../core.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var options = CancelOptions{};
|
||||
var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
|
||||
colors.printError("Failed to allocate job list: {}\n", .{err});
|
||||
return err;
|
||||
};
|
||||
defer job_names.deinit(allocator);
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var force = false;
|
||||
var targets: std.ArrayList([]const u8) = .empty;
|
||||
defer targets.deinit(allocator);
|
||||
|
||||
// Parse arguments for flags and job names
|
||||
// Parse arguments
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
|
||||
if (std.mem.eql(u8, arg, "--force")) {
|
||||
options.force = true;
|
||||
force = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.startsWith(u8, arg, "--help")) {
|
||||
try printUsage();
|
||||
return;
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
return printUsage();
|
||||
} else if (std.mem.startsWith(u8, arg, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
try printUsage();
|
||||
core.output.errorMsg("cancel", "Unknown option");
|
||||
return error.InvalidArgs;
|
||||
} else {
|
||||
// This is a job name
|
||||
try job_names.append(allocator, arg);
|
||||
try targets.append(allocator, arg);
|
||||
}
|
||||
}
|
||||
|
||||
if (job_names.items.len == 0) {
|
||||
colors.printError("No job names specified\n", .{});
|
||||
try printUsage();
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
if (targets.items.len == 0) {
|
||||
core.output.errorMsg("cancel", "No run_id specified");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
// Authenticate with server to get user context
|
||||
var user_context = try auth.authenticateUser(allocator, config);
|
||||
defer user_context.deinit();
|
||||
// Detect mode
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
if (mode_result.warning) |w| {
|
||||
std.log.warn("{s}", .{w});
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
// Connect to WebSocket and send cancel messages
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Process each job
|
||||
var success_count: usize = 0;
|
||||
var failed_jobs = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
|
||||
colors.printError("Failed to allocate failed jobs list: {}\n", .{err});
|
||||
return err;
|
||||
};
|
||||
defer failed_jobs.deinit(allocator);
|
||||
var failed_count: usize = 0;
|
||||
|
||||
for (job_names.items, 0..) |job_name, index| {
|
||||
if (!options.json) {
|
||||
colors.printInfo("Processing job {d}/{d}: {s}\n", .{ index + 1, job_names.items.len, job_name });
|
||||
}
|
||||
|
||||
cancelSingleJob(allocator, &client, user_context, job_name, options, api_key_hash) catch |err| {
|
||||
colors.printError("Failed to cancel job '{s}': {}\n", .{ job_name, err });
|
||||
failed_jobs.append(allocator, job_name) catch |append_err| {
|
||||
colors.printError("Failed to track failed job: {}\n", .{append_err});
|
||||
for (targets.items) |target| {
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
// Local mode: kill by PID
|
||||
cancelLocal(allocator, target, force, flags.json) catch |err| {
|
||||
if (!flags.json) {
|
||||
colors.printError("Failed to cancel '{s}': {}\n", .{ target, err });
|
||||
}
|
||||
failed_count += 1;
|
||||
continue;
|
||||
};
|
||||
continue;
|
||||
};
|
||||
|
||||
} else {
|
||||
// Online mode: cancel on server
|
||||
cancelServer(allocator, target, force, flags.json, cfg) catch |err| {
|
||||
if (!flags.json) {
|
||||
colors.printError("Failed to cancel '{s}': {}\n", .{ target, err });
|
||||
}
|
||||
failed_count += 1;
|
||||
continue;
|
||||
};
|
||||
}
|
||||
success_count += 1;
|
||||
}
|
||||
|
||||
// Show summary
|
||||
if (!options.json) {
|
||||
colors.printInfo("\nCancel Summary:\n", .{});
|
||||
colors.printSuccess("Successfully canceled {d} job(s)\n", .{success_count});
|
||||
if (failed_jobs.items.len > 0) {
|
||||
colors.printError("Failed to cancel {d} job(s):\n", .{failed_jobs.items.len});
|
||||
for (failed_jobs.items) |failed_job| {
|
||||
colors.printError(" - {s}\n", .{failed_job});
|
||||
}
|
||||
if (flags.json) {
|
||||
std.debug.print("{{\"success\":true,\"canceled\":{d},\"failed\":{d}}}\n", .{ success_count, failed_count });
|
||||
} else {
|
||||
colors.printSuccess("Canceled {d} run(s)\n", .{success_count});
|
||||
if (failed_count > 0) {
|
||||
colors.printError("Failed to cancel {d} run(s)\n", .{failed_count});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cancelSingleJob(allocator: std.mem.Allocator, client: *ws.Client, user_context: auth.UserContext, job_name: []const u8, options: CancelOptions, api_key_hash: []const u8) !void {
|
||||
/// Cancel local run by PID
|
||||
fn cancelLocal(allocator: std.mem.Allocator, run_id: []const u8, force: bool, json: bool) !void {
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
// Get DB path
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
// Look up PID
|
||||
const sql = "SELECT pid FROM ml_runs WHERE run_id = ? AND status = 'RUNNING';";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
try db.DB.bindText(stmt, 1, run_id);
|
||||
|
||||
const has_row = try db.DB.step(stmt);
|
||||
if (!has_row) {
|
||||
return error.RunNotFoundOrNotRunning;
|
||||
}
|
||||
|
||||
const pid = db.DB.columnInt64(stmt, 0);
|
||||
if (pid == 0) {
|
||||
return error.NoPIDAvailable;
|
||||
}
|
||||
|
||||
// Send SIGTERM first
|
||||
std.posix.kill(@intCast(pid), std.posix.SIG.TERM) catch |err| {
|
||||
if (err == error.ProcessNotFound) {
|
||||
// Process already gone
|
||||
} else {
|
||||
return err;
|
||||
}
|
||||
};
|
||||
|
||||
if (!force) {
|
||||
// Wait 5 seconds for graceful termination
|
||||
std.Thread.sleep(5 * std.time.ns_per_s);
|
||||
}
|
||||
|
||||
// Check if still running, send SIGKILL if needed
|
||||
if (force or isProcessRunning(@intCast(pid))) {
|
||||
std.posix.kill(@intCast(pid), std.posix.SIG.KILL) catch |err| {
|
||||
if (err != error.ProcessNotFound) {
|
||||
return err;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Update run status
|
||||
const update_sql = "UPDATE ml_runs SET status = 'CANCELLED', pid = NULL WHERE run_id = ?;";
|
||||
const update_stmt = try database.prepare(update_sql);
|
||||
defer db.DB.finalize(update_stmt);
|
||||
try db.DB.bindText(update_stmt, 1, run_id);
|
||||
_ = try db.DB.step(update_stmt);
|
||||
|
||||
// Update manifest
|
||||
const artifact_path = try std.fs.path.join(allocator, &[_][]const u8{
|
||||
cfg.artifact_path,
|
||||
if (cfg.experiment) |exp| exp.name else "default",
|
||||
run_id,
|
||||
"run_manifest.json",
|
||||
});
|
||||
defer allocator.free(artifact_path);
|
||||
|
||||
manifest_lib.updateManifestStatus(artifact_path, "CANCELLED", null, allocator) catch {};
|
||||
|
||||
// Checkpoint
|
||||
database.checkpointOnExit();
|
||||
|
||||
if (!json) {
|
||||
colors.printSuccess("✓ Canceled run {s}\n", .{run_id[0..8]});
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if process is still running
|
||||
fn isProcessRunning(pid: i32) bool {
|
||||
const result = std.posix.kill(pid, 0);
|
||||
return if (result) |_| true else |err| err == error.PermissionDenied;
|
||||
}
|
||||
|
||||
/// Cancel server job
|
||||
fn cancelServer(allocator: std.mem.Allocator, job_name: []const u8, force: bool, json: bool, cfg: config.Config) !void {
|
||||
_ = force;
|
||||
_ = json;
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendCancelJob(job_name, api_key_hash);
|
||||
|
||||
// Receive structured response with user context
|
||||
try client.receiveAndHandleCancelResponse(allocator, user_context, job_name, options);
|
||||
// Wait for acknowledgment
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
// Parse response (simplified)
|
||||
if (std.mem.indexOf(u8, message, "error") != null) {
|
||||
return error.ServerCancelFailed;
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml cancel [options] <job-name> [<job-name> ...]\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --force Force cancel even if job is running\n", .{});
|
||||
colors.printInfo("Usage: ml cancel [options] <run-id> [<run-id> ...]\n", .{});
|
||||
colors.printInfo("\nCancel a local run (kill process) or server job.\n\n", .{});
|
||||
colors.printInfo("Options:\n", .{});
|
||||
colors.printInfo(" --force Force cancel (SIGKILL immediately)\n", .{});
|
||||
colors.printInfo(" --json Output structured JSON\n", .{});
|
||||
colors.printInfo(" --help Show this help message\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help message\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml cancel job1 # Cancel single job\n", .{});
|
||||
colors.printInfo(" ml cancel job1 job2 job3 # Cancel multiple jobs\n", .{});
|
||||
colors.printInfo(" ml cancel --force job1 # Force cancel running job\n", .{});
|
||||
colors.printInfo(" ml cancel --json job1 # Cancel job with JSON output\n", .{});
|
||||
colors.printInfo(" ml cancel --force --json job1 job2 # Force cancel with JSON output\n", .{});
|
||||
colors.printInfo(" ml cancel abc123 # Cancel local run by run_id\n", .{});
|
||||
colors.printInfo(" ml cancel --force abc123 # Force cancel\n", .{});
|
||||
}
|
||||
|
|
|
|||
516
cli/src/commands/compare.zig
Normal file
516
cli/src/commands/compare.zig
Normal file
|
|
@ -0,0 +1,516 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub const CompareOptions = struct {
|
||||
json: bool = false,
|
||||
csv: bool = false,
|
||||
all_fields: bool = false,
|
||||
fields: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len < 2) {
|
||||
core.output.usage("compare", "Expected two run IDs");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
const run_a = argv[0];
|
||||
const run_b = argv[1];
|
||||
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var csv: bool = false;
|
||||
var all_fields: bool = false;
|
||||
var fields: ?[]const u8 = null;
|
||||
|
||||
var i: usize = 2;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const arg = argv[i];
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--csv")) {
|
||||
csv = true;
|
||||
} else if (std.mem.eql(u8, arg, "--all")) {
|
||||
all_fields = true;
|
||||
} else if (std.mem.eql(u8, arg, "--fields") and i + 1 < argv.len) {
|
||||
fields = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
return printUsage();
|
||||
} else {
|
||||
core.output.errorMsg("compare", "Unknown option");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
// Fetch both runs
|
||||
colors.printInfo("Fetching run {s}...\n", .{run_a});
|
||||
var client_a = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client_a.close();
|
||||
|
||||
// Try to get experiment info for run A
|
||||
try client_a.sendGetExperiment(run_a, api_key_hash);
|
||||
const msg_a = try client_a.receiveMessage(allocator);
|
||||
defer allocator.free(msg_a);
|
||||
|
||||
colors.printInfo("Fetching run {s}...\n", .{run_b});
|
||||
var client_b = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client_b.close();
|
||||
|
||||
try client_b.sendGetExperiment(run_b, api_key_hash);
|
||||
const msg_b = try client_b.receiveMessage(allocator);
|
||||
defer allocator.free(msg_b);
|
||||
|
||||
// Parse responses
|
||||
const parsed_a = std.json.parseFromSlice(std.json.Value, allocator, msg_a, .{}) catch {
|
||||
colors.printError("Failed to parse response for {s}\n", .{run_a});
|
||||
return error.InvalidResponse;
|
||||
};
|
||||
defer parsed_a.deinit();
|
||||
|
||||
const parsed_b = std.json.parseFromSlice(std.json.Value, allocator, msg_b, .{}) catch {
|
||||
colors.printError("Failed to parse response for {s}\n", .{run_b});
|
||||
return error.InvalidResponse;
|
||||
};
|
||||
defer parsed_b.deinit();
|
||||
|
||||
const root_a = parsed_a.value.object;
|
||||
const root_b = parsed_b.value.object;
|
||||
|
||||
// Check for errors
|
||||
if (root_a.get("error")) |err_a| {
|
||||
colors.printError("Error fetching {s}: {s}\n", .{ run_a, err_a.string });
|
||||
return error.ServerError;
|
||||
}
|
||||
if (root_b.get("error")) |err_b| {
|
||||
colors.printError("Error fetching {s}: {s}\n", .{ run_b, err_b.string });
|
||||
return error.ServerError;
|
||||
}
|
||||
|
||||
if (flags.json) {
|
||||
try outputJsonComparison(allocator, root_a, root_b, run_a, run_b);
|
||||
} else {
|
||||
try outputHumanComparison(root_a, root_b, run_a, run_b, all_fields);
|
||||
}
|
||||
}
|
||||
|
||||
fn outputHumanComparison(
|
||||
root_a: std.json.ObjectMap,
|
||||
root_b: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
all_fields: bool,
|
||||
) !void {
|
||||
colors.printInfo("\n=== Comparison: {s} vs {s} ===\n\n", .{ run_a, run_b });
|
||||
|
||||
// Common fields
|
||||
const job_name_a = jsonGetString(root_a, "job_name") orelse "unknown";
|
||||
const job_name_b = jsonGetString(root_b, "job_name") orelse "unknown";
|
||||
|
||||
if (!std.mem.eql(u8, job_name_a, job_name_b)) {
|
||||
colors.printWarning("Job names differ:\n", .{});
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, job_name_a });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, job_name_b });
|
||||
} else {
|
||||
colors.printInfo("Job Name: {s}\n", .{job_name_a});
|
||||
}
|
||||
|
||||
// Experiment group
|
||||
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
|
||||
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
|
||||
if (group_a.len > 0 or group_b.len > 0) {
|
||||
colors.printInfo("\nExperiment Group:\n", .{});
|
||||
if (std.mem.eql(u8, group_a, group_b)) {
|
||||
colors.printInfo(" Both: {s}\n", .{group_a});
|
||||
} else {
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, group_a });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, group_b });
|
||||
}
|
||||
}
|
||||
|
||||
// Narrative fields
|
||||
const narrative_a = root_a.get("narrative");
|
||||
const narrative_b = root_b.get("narrative");
|
||||
|
||||
if (narrative_a != null or narrative_b != null) {
|
||||
colors.printInfo("\n--- Narrative ---\n", .{});
|
||||
|
||||
if (narrative_a) |na| {
|
||||
if (narrative_b) |nb| {
|
||||
if (na == .object and nb == .object) {
|
||||
try compareNarrativeFields(na.object, nb.object, run_a, run_b);
|
||||
}
|
||||
} else {
|
||||
colors.printInfo(" {s} has narrative, {s} does not\n", .{ run_a, run_b });
|
||||
}
|
||||
} else if (narrative_b) |_| {
|
||||
colors.printInfo(" {s} has narrative, {s} does not\n", .{ run_b, run_a });
|
||||
}
|
||||
}
|
||||
|
||||
// Metadata differences
|
||||
const meta_a = root_a.get("metadata");
|
||||
const meta_b = root_b.get("metadata");
|
||||
|
||||
if (meta_a) |ma| {
|
||||
if (meta_b) |mb| {
|
||||
if (ma == .object and mb == .object) {
|
||||
colors.printInfo("\n--- Metadata Differences ---\n", .{});
|
||||
try compareMetadata(ma.object, mb.object, run_a, run_b, all_fields);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics (if available)
|
||||
const metrics_a = root_a.get("metrics");
|
||||
const metrics_b = root_b.get("metrics");
|
||||
|
||||
if (metrics_a) |ma| {
|
||||
if (metrics_b) |mb| {
|
||||
if (ma == .object and mb == .object) {
|
||||
colors.printInfo("\n--- Metrics ---\n", .{});
|
||||
try compareMetrics(ma.object, mb.object, run_a, run_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Outcome
|
||||
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
|
||||
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
|
||||
if (outcome_a.len > 0 or outcome_b.len > 0) {
|
||||
colors.printInfo("\n--- Outcome ---\n", .{});
|
||||
if (std.mem.eql(u8, outcome_a, outcome_b)) {
|
||||
colors.printInfo(" Both: {s}\n", .{outcome_a});
|
||||
} else {
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, outcome_a });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, outcome_b });
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\n", .{});
|
||||
}
|
||||
|
||||
fn outputJsonComparison(
|
||||
allocator: std.mem.Allocator,
|
||||
root_a: std.json.ObjectMap,
|
||||
root_b: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
|
||||
try writer.writeAll("{\"run_a\":\"");
|
||||
try writer.writeAll(run_a);
|
||||
try writer.writeAll("\",\"run_b\":\"");
|
||||
try writer.writeAll(run_b);
|
||||
try writer.writeAll("\",\"differences\":");
|
||||
try writer.writeAll("{");
|
||||
|
||||
var first = true;
|
||||
|
||||
// Job names
|
||||
const job_name_a = jsonGetString(root_a, "job_name") orelse "";
|
||||
const job_name_b = jsonGetString(root_b, "job_name") orelse "";
|
||||
if (!std.mem.eql(u8, job_name_a, job_name_b)) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"job_name\":{\"a\":\"");
|
||||
try writer.writeAll(job_name_a);
|
||||
try writer.writeAll("\",\"b\":\"");
|
||||
try writer.writeAll(job_name_b);
|
||||
try writer.writeAll("\"}");
|
||||
}
|
||||
|
||||
// Experiment group
|
||||
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
|
||||
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
|
||||
if (!std.mem.eql(u8, group_a, group_b)) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"experiment_group\":{\"a\":\"");
|
||||
try writer.writeAll(group_a);
|
||||
try writer.writeAll("\",\"b\":\"");
|
||||
try writer.writeAll(group_b);
|
||||
try writer.writeAll("\"}");
|
||||
}
|
||||
|
||||
// Outcomes
|
||||
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
|
||||
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
|
||||
if (!std.mem.eql(u8, outcome_a, outcome_b)) {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"outcome\":{\"a\":\"");
|
||||
try writer.writeAll(outcome_a);
|
||||
try writer.writeAll("\",\"b\":\"");
|
||||
try writer.writeAll(outcome_b);
|
||||
try writer.writeAll("\"}");
|
||||
}
|
||||
|
||||
try writer.writeAll("}}");
|
||||
try writer.writeAll("}\n");
|
||||
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll(buf.items);
|
||||
}
|
||||
|
||||
fn compareNarrativeFields(
|
||||
na: std.json.ObjectMap,
|
||||
nb: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
const fields = [_][]const u8{ "hypothesis", "context", "intent", "expected_outcome" };
|
||||
|
||||
for (fields) |field| {
|
||||
const val_a = jsonGetString(na, field);
|
||||
const val_b = jsonGetString(nb, field);
|
||||
|
||||
if (val_a != null and val_b != null) {
|
||||
if (!std.mem.eql(u8, val_a.?, val_b.?)) {
|
||||
colors.printInfo(" {s}:\n", .{field});
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_a, val_a.? });
|
||||
colors.printInfo(" {s}: {s}\n", .{ run_b, val_b.? });
|
||||
}
|
||||
} else if (val_a != null) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ field, run_a });
|
||||
} else if (val_b != null) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ field, run_b });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compareMetadata(
|
||||
ma: std.json.ObjectMap,
|
||||
mb: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
show_all: bool,
|
||||
) !void {
|
||||
var has_differences = false;
|
||||
|
||||
// Compare key metadata fields
|
||||
const keys = [_][]const u8{ "batch_size", "learning_rate", "epochs", "model", "dataset" };
|
||||
|
||||
for (keys) |key| {
|
||||
if (ma.get(key)) |va| {
|
||||
if (mb.get(key)) |vb| {
|
||||
const str_a = jsonValueToString(va);
|
||||
const str_b = jsonValueToString(vb);
|
||||
|
||||
if (!std.mem.eql(u8, str_a, str_b)) {
|
||||
has_differences = true;
|
||||
colors.printInfo(" {s}: {s} → {s}\n", .{ key, str_a, str_b });
|
||||
} else if (show_all) {
|
||||
colors.printInfo(" {s}: {s} (same)\n", .{ key, str_a });
|
||||
}
|
||||
} else if (show_all) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ key, run_a });
|
||||
}
|
||||
} else if (mb.get(key)) |_| {
|
||||
if (show_all) {
|
||||
colors.printInfo(" {s}: only in {s}\n", .{ key, run_b });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_differences and !show_all) {
|
||||
colors.printInfo(" (no significant differences in common metadata)\n", .{});
|
||||
}
|
||||
}
|
||||
|
||||
fn compareMetrics(
|
||||
ma: std.json.ObjectMap,
|
||||
mb: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
_ = run_a;
|
||||
_ = run_b;
|
||||
|
||||
// Common metrics to compare
|
||||
const metrics = [_][]const u8{ "accuracy", "loss", "f1_score", "precision", "recall", "training_time", "validation_loss" };
|
||||
|
||||
for (metrics) |metric| {
|
||||
if (ma.get(metric)) |va| {
|
||||
if (mb.get(metric)) |vb| {
|
||||
const val_a = jsonValueToFloat(va);
|
||||
const val_b = jsonValueToFloat(vb);
|
||||
|
||||
const diff = val_b - val_a;
|
||||
const percent = if (val_a != 0) (diff / val_a) * 100 else 0;
|
||||
|
||||
const arrow = if (diff > 0) "↑" else if (diff < 0) "↓" else "=";
|
||||
|
||||
colors.printInfo(" {s}: {d:.4} → {d:.4} ({s}{d:.4}, {d:.1}%)\n", .{
|
||||
metric, val_a, val_b, arrow, @abs(diff), percent,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn outputCsvComparison(
|
||||
allocator: std.mem.Allocator,
|
||||
root_a: std.json.ObjectMap,
|
||||
root_b: std.json.ObjectMap,
|
||||
run_a: []const u8,
|
||||
run_b: []const u8,
|
||||
) !void {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
|
||||
// Header with actual run IDs as column names
|
||||
try writer.print("field,{s},{s},delta,notes\n", .{ run_a, run_b });
|
||||
|
||||
// Job names
|
||||
const job_name_a = jsonGetString(root_a, "job_name") orelse "";
|
||||
const job_name_b = jsonGetString(root_b, "job_name") orelse "";
|
||||
const job_same = std.mem.eql(u8, job_name_a, job_name_b);
|
||||
try writer.print("job_name,\"{s}\",\"{s}\",{s},\"{s}\"\n", .{
|
||||
job_name_a, job_name_b,
|
||||
if (job_same) "same" else "changed", if (job_same) "" else "different job names",
|
||||
});
|
||||
|
||||
// Outcomes
|
||||
const outcome_a = jsonGetString(root_a, "outcome") orelse "";
|
||||
const outcome_b = jsonGetString(root_b, "outcome") orelse "";
|
||||
const outcome_same = std.mem.eql(u8, outcome_a, outcome_b);
|
||||
try writer.print("outcome,{s},{s},{s},\"{s}\"\n", .{
|
||||
outcome_a, outcome_b,
|
||||
if (outcome_same) "same" else "changed", if (outcome_same) "" else "different outcomes",
|
||||
});
|
||||
|
||||
// Experiment group
|
||||
const group_a = jsonGetString(root_a, "experiment_group") orelse "";
|
||||
const group_b = jsonGetString(root_b, "experiment_group") orelse "";
|
||||
const group_same = std.mem.eql(u8, group_a, group_b);
|
||||
try writer.print("experiment_group,\"{s}\",\"{s}\",{s},\"{s}\"\n", .{
|
||||
group_a, group_b,
|
||||
if (group_same) "same" else "changed", if (group_same) "" else "different groups",
|
||||
});
|
||||
|
||||
// Metadata fields with delta calculation for numeric values
|
||||
const keys = [_][]const u8{ "batch_size", "learning_rate", "epochs", "model", "dataset" };
|
||||
for (keys) |key| {
|
||||
if (root_a.get(key)) |va| {
|
||||
if (root_b.get(key)) |vb| {
|
||||
const str_a = jsonValueToString(va);
|
||||
const str_b = jsonValueToString(vb);
|
||||
const same = std.mem.eql(u8, str_a, str_b);
|
||||
|
||||
// Try to calculate delta for numeric values
|
||||
const delta = if (!same) blk: {
|
||||
const f_a = jsonValueToFloat(va);
|
||||
const f_b = jsonValueToFloat(vb);
|
||||
if (f_a != 0 or f_b != 0) {
|
||||
break :blk try std.fmt.allocPrint(allocator, "{d:.4}", .{f_b - f_a});
|
||||
}
|
||||
break :blk "changed";
|
||||
} else "0";
|
||||
defer if (!same and (jsonValueToFloat(va) != 0 or jsonValueToFloat(vb) != 0)) allocator.free(delta);
|
||||
|
||||
try writer.print("{s},{s},{s},{s},\"{s}\"\n", .{
|
||||
key, str_a, str_b, delta,
|
||||
if (same) "same" else "changed",
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics with delta calculation
|
||||
const metrics = [_][]const u8{ "accuracy", "loss", "f1_score", "precision", "recall", "training_time" };
|
||||
for (metrics) |metric| {
|
||||
if (root_a.get(metric)) |va| {
|
||||
if (root_b.get(metric)) |vb| {
|
||||
const val_a = jsonValueToFloat(va);
|
||||
const val_b = jsonValueToFloat(vb);
|
||||
const diff = val_b - val_a;
|
||||
const percent = if (val_a != 0) (diff / val_a) * 100 else 0;
|
||||
|
||||
const notes = if (std.mem.eql(u8, metric, "loss") or std.mem.eql(u8, metric, "training_time"))
|
||||
if (diff < 0) "improved" else if (diff > 0) "degraded" else "same"
|
||||
else if (diff > 0) "improved" else if (diff < 0) "degraded" else "same";
|
||||
|
||||
try writer.print("{s},{d:.4},{d:.4},{d:.4},\"{d:.1}% - {s}\"\n", .{
|
||||
metric, val_a, val_b, diff, percent, notes,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll(buf.items);
|
||||
}
|
||||
|
||||
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
|
||||
const v_opt = obj.get(key);
|
||||
if (v_opt == null) return null;
|
||||
const v = v_opt.?;
|
||||
if (v != .string) return null;
|
||||
return v.string;
|
||||
}
|
||||
|
||||
fn jsonValueToString(v: std.json.Value) []const u8 {
|
||||
return switch (v) {
|
||||
.string => |s| s,
|
||||
.integer => "number",
|
||||
.float => "number",
|
||||
.bool => |b| if (b) "true" else "false",
|
||||
else => "complex",
|
||||
};
|
||||
}
|
||||
|
||||
fn jsonValueToFloat(v: std.json.Value) f64 {
|
||||
return switch (v) {
|
||||
.float => |f| f,
|
||||
.integer => |i| @as(f64, @floatFromInt(i)),
|
||||
else => 0,
|
||||
};
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml compare <run-a> <run-b> [options]\n", .{});
|
||||
colors.printInfo("\nCompare two runs and show differences in:\n", .{});
|
||||
colors.printInfo(" - Job metadata (batch_size, learning_rate, etc.)\n", .{});
|
||||
colors.printInfo(" - Narrative fields (hypothesis, context, intent)\n", .{});
|
||||
colors.printInfo(" - Metrics (accuracy, loss, training_time)\n", .{});
|
||||
colors.printInfo(" - Outcome status\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --json Output as JSON\n", .{});
|
||||
colors.printInfo(" --all Show all fields (including unchanged)\n", .{});
|
||||
colors.printInfo(" --fields <csv> Compare only specific fields\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml compare run_abc run_def\n", .{});
|
||||
colors.printInfo(" ml compare run_abc run_def --json\n", .{});
|
||||
colors.printInfo(" ml compare run_abc run_def --all\n", .{});
|
||||
}
|
||||
|
|
@ -4,20 +4,26 @@ const ws = @import("../net/ws/client.zig");
|
|||
const colors = @import("../utils/colors.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const core = @import("../core.zig");
|
||||
const native_hash = @import("../native/hash.zig");
|
||||
|
||||
const DatasetOptions = struct {
|
||||
dry_run: bool = false,
|
||||
validate: bool = false,
|
||||
json: bool = false,
|
||||
csv: bool = false,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
var options = DatasetOptions{};
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var dry_run = false;
|
||||
var validate = false;
|
||||
var csv = false;
|
||||
|
||||
// Parse global flags: --dry-run, --validate, --json
|
||||
var positional = std.ArrayList([]const u8).initCapacity(allocator, args.len) catch |err| {
|
||||
return err;
|
||||
|
|
@ -26,57 +32,67 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
return;
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, arg, "--dry-run")) {
|
||||
options.dry_run = true;
|
||||
dry_run = true;
|
||||
} else if (std.mem.eql(u8, arg, "--validate")) {
|
||||
options.validate = true;
|
||||
validate = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--csv")) {
|
||||
csv = true;
|
||||
} else if (std.mem.startsWith(u8, arg, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
core.output.errorMsg("dataset", "Unknown option");
|
||||
return printUsage();
|
||||
} else {
|
||||
try positional.append(allocator, arg);
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const action = positional.items[0];
|
||||
|
||||
switch (positional.items.len) {
|
||||
0 => {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
return printUsage();
|
||||
},
|
||||
1 => {
|
||||
if (std.mem.eql(u8, action, "list")) {
|
||||
const options = DatasetOptions{ .json = flags.json, .csv = csv };
|
||||
try listDatasets(allocator, &options);
|
||||
return error.InvalidArgs;
|
||||
return;
|
||||
}
|
||||
},
|
||||
2 => {
|
||||
if (std.mem.eql(u8, action, "info")) {
|
||||
const options = DatasetOptions{ .json = flags.json, .csv = csv };
|
||||
try showDatasetInfo(allocator, positional.items[1], &options);
|
||||
return;
|
||||
} else if (std.mem.eql(u8, action, "search")) {
|
||||
const options = DatasetOptions{ .json = flags.json, .csv = csv };
|
||||
try searchDatasets(allocator, positional.items[1], &options);
|
||||
return error.InvalidArgs;
|
||||
return;
|
||||
} else if (std.mem.eql(u8, action, "verify")) {
|
||||
const options = DatasetOptions{ .json = flags.json, .validate = validate };
|
||||
try verifyDataset(allocator, positional.items[1], &options);
|
||||
return;
|
||||
}
|
||||
},
|
||||
3 => {
|
||||
if (std.mem.eql(u8, action, "register")) {
|
||||
const options = DatasetOptions{ .json = flags.json, .dry_run = dry_run };
|
||||
try registerDataset(allocator, positional.items[1], positional.items[2], &options);
|
||||
return error.InvalidArgs;
|
||||
return;
|
||||
}
|
||||
},
|
||||
else => {
|
||||
colors.printError("Unknoen action: {s}\n", .{action});
|
||||
printUsage();
|
||||
core.output.errorMsg("dataset", "Too many arguments");
|
||||
return error.InvalidArgs;
|
||||
},
|
||||
}
|
||||
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
fn printUsage() void {
|
||||
|
|
@ -86,6 +102,7 @@ fn printUsage() void {
|
|||
colors.printInfo(" register <name> <url> Register a dataset with URL\n", .{});
|
||||
colors.printInfo(" info <name> Show dataset information\n", .{});
|
||||
colors.printInfo(" search <term> Search datasets by name/description\n", .{});
|
||||
colors.printInfo(" verify <path|id> Verify dataset integrity (auto-hashes)\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --dry-run Show what would be requested\n", .{});
|
||||
colors.printInfo(" --validate Validate inputs only (no request)\n", .{});
|
||||
|
|
@ -362,6 +379,118 @@ fn searchDatasets(allocator: std.mem.Allocator, term: []const u8, options: *cons
|
|||
}
|
||||
}
|
||||
|
||||
fn verifyDataset(allocator: std.mem.Allocator, target: []const u8, options: *const DatasetOptions) !void {
|
||||
colors.printInfo("Verifying dataset: {s}\n", .{target});
|
||||
|
||||
const path = if (std.fs.path.isAbsolute(target))
|
||||
target
|
||||
else
|
||||
try std.fs.path.join(allocator, &[_][]const u8{ ".", target });
|
||||
defer if (!std.fs.path.isAbsolute(target)) allocator.free(path);
|
||||
|
||||
var dir = std.fs.openDirAbsolute(path, .{ .iterate = true }) catch {
|
||||
colors.printError("Dataset not found: {s}\n", .{target});
|
||||
return error.FileNotFound;
|
||||
};
|
||||
defer dir.close();
|
||||
|
||||
var file_count: usize = 0;
|
||||
var total_size: u64 = 0;
|
||||
|
||||
var walker = try dir.walk(allocator);
|
||||
defer walker.deinit();
|
||||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind != .file) continue;
|
||||
file_count += 1;
|
||||
|
||||
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ path, entry.path });
|
||||
defer allocator.free(full_path);
|
||||
|
||||
const stat = std.fs.cwd().statFile(full_path) catch continue;
|
||||
total_size += stat.size;
|
||||
}
|
||||
|
||||
// Compute native SHA256 hash
|
||||
const hash = blk: {
|
||||
break :blk native_hash.hashDirectory(allocator, path) catch |err| {
|
||||
colors.printWarning("Hash computation failed: {s}\n", .{@errorName(err)});
|
||||
// Continue without hash - verification still succeeded
|
||||
break :blk null;
|
||||
};
|
||||
};
|
||||
defer if (hash) |h| allocator.free(h);
|
||||
|
||||
if (options.json) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
var buffer: [4096]u8 = undefined;
|
||||
const hash_str = if (hash) |h| h else "null";
|
||||
const formatted = std.fmt.bufPrint(&buffer, "{{\"path\":\"{s}\",\"files\":{d},\"size\":{d},\"hash\":\"{s}\",\"ok\":true}}\n", .{
|
||||
target, file_count, total_size, hash_str,
|
||||
}) catch unreachable;
|
||||
try stdout_file.writeAll(formatted);
|
||||
} else if (options.csv) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
try stdout_file.writeAll("metric,value\n");
|
||||
var buf: [256]u8 = undefined;
|
||||
const line1 = try std.fmt.bufPrint(&buf, "path,{s}\n", .{target});
|
||||
try stdout_file.writeAll(line1);
|
||||
const line2 = try std.fmt.bufPrint(&buf, "files,{d}\n", .{file_count});
|
||||
try stdout_file.writeAll(line2);
|
||||
const line3 = try std.fmt.bufPrint(&buf, "size_bytes,{d}\n", .{total_size});
|
||||
try stdout_file.writeAll(line3);
|
||||
const line4 = try std.fmt.bufPrint(&buf, "size_mb,{d:.2}\n", .{@as(f64, @floatFromInt(total_size)) / (1024 * 1024)});
|
||||
try stdout_file.writeAll(line4);
|
||||
if (hash) |h| {
|
||||
const line5 = try std.fmt.bufPrint(&buf, "sha256,{s}\n", .{h});
|
||||
try stdout_file.writeAll(line5);
|
||||
}
|
||||
} else {
|
||||
colors.printSuccess("✓ Dataset verified\n", .{});
|
||||
colors.printInfo(" Path: {s}\n", .{target});
|
||||
colors.printInfo(" Files: {d}\n", .{file_count});
|
||||
colors.printInfo(" Size: {d:.2} MB\n", .{@as(f64, @floatFromInt(total_size)) / (1024 * 1024)});
|
||||
if (hash) |h| {
|
||||
colors.printInfo(" SHA256: {s}\n", .{h});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hashDataset(allocator: std.mem.Allocator, path: []const u8) !void {
|
||||
colors.printInfo("Computing native SHA256 hash for: {s}\n", .{path});
|
||||
|
||||
// Check SIMD availability
|
||||
if (!native_hash.hasSimdSha256()) {
|
||||
colors.printWarning("SIMD SHA256 not available, using generic implementation\n", .{});
|
||||
} else {
|
||||
const impl_name = native_hash.getSimdImplName();
|
||||
colors.printInfo("Using {s} SHA256 implementation\n", .{impl_name});
|
||||
}
|
||||
|
||||
// Compute hash using native library
|
||||
const hash = native_hash.hashDirectory(allocator, path) catch |err| {
|
||||
switch (err) {
|
||||
error.ContextInitFailed => {
|
||||
colors.printError("Failed to initialize native hash context\n", .{});
|
||||
},
|
||||
error.HashFailed => {
|
||||
colors.printError("Hash computation failed\n", .{});
|
||||
},
|
||||
error.InvalidPath => {
|
||||
colors.printError("Invalid path: {s}\n", .{path});
|
||||
},
|
||||
error.OutOfMemory => {
|
||||
colors.printError("Out of memory\n", .{});
|
||||
},
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(hash);
|
||||
|
||||
// Print result
|
||||
colors.printSuccess("SHA256: {s}\n", .{hash});
|
||||
}
|
||||
|
||||
fn writeJSONString(writer: anytype, s: []const u8) !void {
|
||||
try writer.writeByte('"');
|
||||
for (s) |c| {
|
||||
|
|
|
|||
53
cli/src/commands/dataset_hash.zig
Normal file
53
cli/src/commands/dataset_hash.zig
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
const std = @import("std");
|
||||
const cli = @import("../../main.zig");
|
||||
const native_hash = @import("../../native/hash.zig");
|
||||
const ui = @import("../../ui/ui.zig");
|
||||
const colors = @import("../../ui/colors.zig");
|
||||
|
||||
pub const name = "dataset hash";
|
||||
pub const description = "Hash a dataset directory using native SHA256 library";
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
// Parse arguments
|
||||
if (args.len < 1) {
|
||||
try ui.printHelp(name, description, &.{
|
||||
.{ "<path>", "Path to dataset directory" },
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const path = args[0];
|
||||
|
||||
// Check if native library is available
|
||||
if (!native_hash.hasSimdSha256()) {
|
||||
colors.printWarning("SIMD SHA256 not available, using generic implementation\n", .{});
|
||||
} else {
|
||||
const impl_name = native_hash.getSimdImplName();
|
||||
colors.printInfo("Using {s} SHA256 implementation\n", .{impl_name});
|
||||
}
|
||||
|
||||
// Hash the directory
|
||||
colors.printInfo("Hashing dataset at: {s}\n", .{path});
|
||||
|
||||
const hash = native_hash.hashDirectory(allocator, path) catch |err| {
|
||||
switch (err) {
|
||||
error.ContextInitFailed => {
|
||||
colors.printError("Failed to initialize native hash context\n", .{});
|
||||
},
|
||||
error.HashFailed => {
|
||||
colors.printError("Hash computation failed\n", .{});
|
||||
},
|
||||
error.InvalidPath => {
|
||||
colors.printError("Invalid path: {s}\n", .{path});
|
||||
},
|
||||
error.OutOfMemory => {
|
||||
colors.printError("Out of memory\n", .{});
|
||||
},
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(hash);
|
||||
|
||||
// Print result
|
||||
colors.printSuccess("Dataset hash: {s}\n", .{hash});
|
||||
}
|
||||
|
|
@ -1,644 +1,380 @@
|
|||
const std = @import("std");
|
||||
const config = @import("../config.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const history = @import("../utils/history.zig");
|
||||
const db = @import("../db.zig");
|
||||
const core = @import("../core.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const cancel_cmd = @import("cancel.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const uuid = @import("../utils/uuid.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
|
||||
fn jsonError(command: []const u8, message: []const u8) void {
|
||||
std.debug.print(
|
||||
"{{\"success\":false,\"command\":\"{s}\",\"error\":\"{s}\"}}\n",
|
||||
.{ command, message },
|
||||
);
|
||||
}
|
||||
const ExperimentInfo = struct {
|
||||
id: []const u8,
|
||||
name: []const u8,
|
||||
description: []const u8,
|
||||
created_at: []const u8,
|
||||
status: []const u8,
|
||||
synced: bool,
|
||||
|
||||
fn jsonErrorWithDetails(command: []const u8, message: []const u8, details: []const u8) void {
|
||||
std.debug.print(
|
||||
"{{\"success\":false,\"command\":\"{s}\",\"error\":\"{s}\",\"details\":\"{s}\"}}\n",
|
||||
.{ command, message, details },
|
||||
);
|
||||
}
|
||||
|
||||
const ExperimentOptions = struct {
|
||||
json: bool = false,
|
||||
help: bool = false,
|
||||
fn deinit(self: *ExperimentInfo, allocator: std.mem.Allocator) void {
|
||||
allocator.free(self.id);
|
||||
allocator.free(self.name);
|
||||
allocator.free(self.description);
|
||||
allocator.free(self.created_at);
|
||||
allocator.free(self.status);
|
||||
}
|
||||
};
|
||||
|
||||
/// Experiment command - manage experiments
|
||||
/// Usage:
|
||||
/// ml experiment create --name "baseline-cnn"
|
||||
/// ml experiment list
|
||||
/// ml experiment show <experiment_id>
|
||||
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var options = ExperimentOptions{};
|
||||
var command_args = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
|
||||
return err;
|
||||
};
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var command_args = try core.flags.parseCommon(allocator, args, &flags);
|
||||
defer command_args.deinit(allocator);
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
options.help = true;
|
||||
} else {
|
||||
try command_args.append(allocator, arg);
|
||||
}
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
if (flags.help or command_args.items.len == 0) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
if (command_args.items.len < 1 or options.help) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
const subcommand = command_args.items[0];
|
||||
const sub_args = if (command_args.items.len > 1) command_args.items[1..] else &[_][]const u8{};
|
||||
|
||||
const command = command_args.items[0];
|
||||
|
||||
if (std.mem.eql(u8, command, "init")) {
|
||||
try executeInit(allocator, command_args.items[1..], &options);
|
||||
} else if (std.mem.eql(u8, command, "log")) {
|
||||
try executeLog(allocator, command_args.items[1..], &options);
|
||||
} else if (std.mem.eql(u8, command, "show")) {
|
||||
try executeShow(allocator, command_args.items[1..], &options);
|
||||
} else if (std.mem.eql(u8, command, "list")) {
|
||||
try executeList(allocator, &options);
|
||||
} else if (std.mem.eql(u8, command, "delete")) {
|
||||
if (command_args.items.len < 2) {
|
||||
if (options.json) {
|
||||
jsonError("experiment.delete", "Usage: ml experiment delete <alias|commit>");
|
||||
} else {
|
||||
colors.printError("Usage: ml experiment delete <alias|commit>\n", .{});
|
||||
}
|
||||
return;
|
||||
}
|
||||
try executeDelete(allocator, command_args.items[1], &options);
|
||||
if (std.mem.eql(u8, subcommand, "create")) {
|
||||
return try createExperiment(allocator, sub_args, flags.json);
|
||||
} else if (std.mem.eql(u8, subcommand, "list")) {
|
||||
return try listExperiments(allocator, sub_args, flags.json);
|
||||
} else if (std.mem.eql(u8, subcommand, "show")) {
|
||||
return try showExperiment(allocator, sub_args, flags.json);
|
||||
} else {
|
||||
if (options.json) {
|
||||
const msg = try std.fmt.allocPrint(allocator, "Unknown command: {s}", .{command});
|
||||
defer allocator.free(msg);
|
||||
jsonError("experiment", msg);
|
||||
} else {
|
||||
colors.printError("Unknown command: {s}\n", .{command});
|
||||
try printUsage();
|
||||
}
|
||||
const msg = try std.fmt.allocPrint(allocator, "Unknown subcommand: {s}", .{subcommand});
|
||||
defer allocator.free(msg);
|
||||
core.output.errorMsg("experiment", msg);
|
||||
return printUsage();
|
||||
}
|
||||
}
|
||||
|
||||
fn executeInit(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void {
|
||||
fn createExperiment(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
var name: ?[]const u8 = null;
|
||||
var description: ?[]const u8 = null;
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
if (std.mem.eql(u8, arg, "--name")) {
|
||||
if (i + 1 < args.len) {
|
||||
name = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--description")) {
|
||||
if (i + 1 < args.len) {
|
||||
description = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
|
||||
name = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--description") and i + 1 < args.len) {
|
||||
description = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Generate experiment ID and commit ID
|
||||
const stdcrypto = std.crypto;
|
||||
var exp_id_bytes: [16]u8 = undefined;
|
||||
stdcrypto.random.bytes(&exp_id_bytes);
|
||||
|
||||
var commit_id_bytes: [20]u8 = undefined;
|
||||
stdcrypto.random.bytes(&commit_id_bytes);
|
||||
|
||||
const exp_id = try crypto.encodeHexLower(allocator, &exp_id_bytes);
|
||||
defer allocator.free(exp_id);
|
||||
|
||||
const commit_id = try crypto.encodeHexLower(allocator, &commit_id_bytes);
|
||||
defer allocator.free(commit_id);
|
||||
|
||||
const exp_name = name orelse "unnamed-experiment";
|
||||
const exp_desc = description orelse "No description provided";
|
||||
|
||||
if (options.json) {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.init\",\"data\":{{\"experiment_id\":\"{s}\",\"commit_id\":\"{s}\",\"name\":\"{s}\",\"description\":\"{s}\",\"status\":\"initialized\"}}}}\n",
|
||||
.{ exp_id, commit_id, exp_name, exp_desc },
|
||||
);
|
||||
} else {
|
||||
colors.printInfo("Experiment initialized successfully!\n", .{});
|
||||
colors.printInfo("Experiment ID: {s}\n", .{exp_id});
|
||||
colors.printInfo("Commit ID: {s}\n", .{commit_id});
|
||||
colors.printInfo("Name: {s}\n", .{exp_name});
|
||||
colors.printInfo("Description: {s}\n", .{exp_desc});
|
||||
colors.printInfo("Status: initialized\n", .{});
|
||||
colors.printInfo("Use this commit ID when queuing jobs: --commit-id {s}\n", .{commit_id});
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml experiment [options] <command> [args]\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --json Output structured JSON\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help message\n", .{});
|
||||
colors.printInfo("\nCommands:\n", .{});
|
||||
colors.printInfo(" init Initialize a new experiment\n", .{});
|
||||
colors.printInfo(" log Log a metric for an experiment\n", .{});
|
||||
colors.printInfo(" show <commit_id> Show experiment details\n", .{});
|
||||
colors.printInfo(" list List recent experiments\n", .{});
|
||||
colors.printInfo(" delete <alias|commit> Cancel/delete an experiment\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml experiment init --name \"my-experiment\" --description \"Test experiment\"\n", .{});
|
||||
colors.printInfo(" ml experiment show abc123 --json\n", .{});
|
||||
colors.printInfo(" ml experiment list --json\n", .{});
|
||||
}
|
||||
|
||||
fn executeLog(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void {
|
||||
var commit_id: ?[]const u8 = null;
|
||||
var name: ?[]const u8 = null;
|
||||
var value: ?f64 = null;
|
||||
var step: u32 = 0;
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
if (std.mem.eql(u8, arg, "--id")) {
|
||||
if (i + 1 < args.len) {
|
||||
commit_id = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--name")) {
|
||||
if (i + 1 < args.len) {
|
||||
name = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--value")) {
|
||||
if (i + 1 < args.len) {
|
||||
value = try std.fmt.parseFloat(f64, args[i + 1]);
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--step")) {
|
||||
if (i + 1 < args.len) {
|
||||
step = try std.fmt.parseInt(u32, args[i + 1], 10);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
if (name == null) {
|
||||
core.output.errorMsg("experiment", "--name is required");
|
||||
return error.MissingArgument;
|
||||
}
|
||||
|
||||
if (commit_id == null or name == null or value == null) {
|
||||
if (options.json) {
|
||||
jsonError("experiment.log", "Usage: ml experiment log --id <commit_id> --name <name> --value <value> [--step <step>]");
|
||||
} else {
|
||||
colors.printError("Usage: ml experiment log --id <commit_id> --name <name> --value <value> [--step <step>]\n", .{});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const Config = @import("../config.zig").Config;
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
// Check mode
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
// Local mode: create in SQLite
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
try client.sendLogMetric(api_key_hash, commit_id.?, name.?, value.?, step);
|
||||
// TODO: Add synced column to schema - required for server sync
|
||||
const sql = "INSERT INTO ml_experiments (experiment_id, name, description, status, synced) VALUES (?, ?, ?, 'active', 0);";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
|
||||
if (options.json) {
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
const exp_id = try generateExperimentID(allocator);
|
||||
defer allocator.free(exp_id);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"commit_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}},\"message\":\"{s}\"}}}}\n",
|
||||
.{ commit_id.?, name.?, value.?, step, message },
|
||||
);
|
||||
return;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
try db.DB.bindText(stmt, 1, exp_id);
|
||||
try db.DB.bindText(stmt, 2, name.?);
|
||||
try db.DB.bindText(stmt, 3, description orelse "");
|
||||
_ = try db.DB.step(stmt);
|
||||
|
||||
switch (packet.packet_type) {
|
||||
.success => {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"commit_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}},\"message\":\"{s}\"}}}}\n",
|
||||
.{ commit_id.?, name.?, value.?, step, message },
|
||||
);
|
||||
return;
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
} else {
|
||||
try client.receiveAndHandleResponse(allocator, "Log metric");
|
||||
colors.printSuccess("Metric logged successfully!\n", .{});
|
||||
colors.printInfo("Commit ID: {s}\n", .{commit_id.?});
|
||||
colors.printInfo("Metric: {s} = {d:.4} (step {d})\n", .{ name.?, value.?, step });
|
||||
}
|
||||
}
|
||||
|
||||
fn executeShow(allocator: std.mem.Allocator, args: []const []const u8, options: *const ExperimentOptions) !void {
|
||||
if (args.len < 1) {
|
||||
if (options.json) {
|
||||
jsonError("experiment.show", "Usage: ml experiment show <commit_id|alias>");
|
||||
} else {
|
||||
colors.printError("Usage: ml experiment show <commit_id|alias>\n", .{});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const identifier = args[0];
|
||||
const commit_id = try resolveCommitIdentifier(allocator, identifier);
|
||||
defer allocator.free(commit_id);
|
||||
|
||||
const Config = @import("../config.zig").Config;
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
// Update config with new experiment
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
if (mut_cfg.experiment == null) {
|
||||
mut_cfg.experiment = config.ExperimentConfig{
|
||||
.name = "",
|
||||
.entrypoint = "",
|
||||
};
|
||||
}
|
||||
mut_cfg.experiment.?.name = try allocator.dupe(u8, name.?);
|
||||
try mut_cfg.save(allocator);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
database.checkpointOnExit();
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendGetExperiment(api_key_hash, commit_id);
|
||||
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
const packet = try protocol.ResponsePacket.deserialize(message, allocator);
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
// For now, let's just print the result
|
||||
switch (packet.packet_type) {
|
||||
.success, .data => {
|
||||
if (packet.data_payload) |payload| {
|
||||
if (options.json) {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.show\",\"data\":{s}}}\n",
|
||||
.{payload},
|
||||
);
|
||||
return;
|
||||
} else {
|
||||
// Parse JSON response
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch |err| {
|
||||
colors.printError("Failed to parse response: {}\n", .{err});
|
||||
return;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value;
|
||||
if (root != .object) {
|
||||
colors.printError("Invalid response format\n", .{});
|
||||
return;
|
||||
}
|
||||
|
||||
const metadata = root.object.get("metadata");
|
||||
const metrics = root.object.get("metrics");
|
||||
|
||||
if (metadata != null and metadata.? == .object) {
|
||||
colors.printInfo("\nExperiment Details:\n", .{});
|
||||
colors.printInfo("-------------------\n", .{});
|
||||
const m = metadata.?.object;
|
||||
if (m.get("JobName")) |v| colors.printInfo("Job Name: {s}\n", .{v.string});
|
||||
if (m.get("CommitID")) |v| colors.printInfo("Commit ID: {s}\n", .{v.string});
|
||||
if (m.get("User")) |v| colors.printInfo("User: {s}\n", .{v.string});
|
||||
if (m.get("Timestamp")) |v| {
|
||||
const ts = v.integer;
|
||||
colors.printInfo("Timestamp: {d}\n", .{ts});
|
||||
}
|
||||
}
|
||||
|
||||
if (metrics != null and metrics.? == .array) {
|
||||
colors.printInfo("\nMetrics:\n", .{});
|
||||
colors.printInfo("-------------------\n", .{});
|
||||
const items = metrics.?.array.items;
|
||||
if (items.len == 0) {
|
||||
colors.printInfo("No metrics logged.\n", .{});
|
||||
} else {
|
||||
for (items) |item| {
|
||||
if (item == .object) {
|
||||
const name = item.object.get("name").?.string;
|
||||
const value = item.object.get("value").?.float;
|
||||
const step = item.object.get("step").?.integer;
|
||||
colors.printInfo("{s}: {d:.4} (Step: {d})\n", .{ name, value, step });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const repro = root.object.get("reproducibility");
|
||||
if (repro != null and repro.? == .object) {
|
||||
colors.printInfo("\nReproducibility:\n", .{});
|
||||
colors.printInfo("-------------------\n", .{});
|
||||
|
||||
const repro_obj = repro.?.object;
|
||||
if (repro_obj.get("experiment")) |exp_val| {
|
||||
if (exp_val == .object) {
|
||||
const e = exp_val.object;
|
||||
if (e.get("id")) |v| colors.printInfo("Experiment ID: {s}\n", .{v.string});
|
||||
if (e.get("name")) |v| colors.printInfo("Name: {s}\n", .{v.string});
|
||||
if (e.get("status")) |v| colors.printInfo("Status: {s}\n", .{v.string});
|
||||
if (e.get("user_id")) |v| colors.printInfo("User ID: {s}\n", .{v.string});
|
||||
}
|
||||
}
|
||||
|
||||
if (repro_obj.get("environment")) |env_val| {
|
||||
if (env_val == .object) {
|
||||
const env = env_val.object;
|
||||
if (env.get("python_version")) |v| colors.printInfo("Python: {s}\n", .{v.string});
|
||||
if (env.get("cuda_version")) |v| colors.printInfo("CUDA: {s}\n", .{v.string});
|
||||
if (env.get("system_os")) |v| colors.printInfo("OS: {s}\n", .{v.string});
|
||||
if (env.get("system_arch")) |v| colors.printInfo("Arch: {s}\n", .{v.string});
|
||||
if (env.get("hostname")) |v| colors.printInfo("Hostname: {s}\n", .{v.string});
|
||||
if (env.get("requirements_hash")) |v| colors.printInfo("Requirements hash: {s}\n", .{v.string});
|
||||
}
|
||||
}
|
||||
|
||||
if (repro_obj.get("git_info")) |git_val| {
|
||||
if (git_val == .object) {
|
||||
const g = git_val.object;
|
||||
if (g.get("commit_sha")) |v| colors.printInfo("Git SHA: {s}\n", .{v.string});
|
||||
if (g.get("branch")) |v| colors.printInfo("Git branch: {s}\n", .{v.string});
|
||||
if (g.get("remote_url")) |v| colors.printInfo("Git remote: {s}\n", .{v.string});
|
||||
if (g.get("is_dirty")) |v| colors.printInfo("Git dirty: {}\n", .{v.bool});
|
||||
}
|
||||
}
|
||||
|
||||
if (repro_obj.get("seeds")) |seeds_val| {
|
||||
if (seeds_val == .object) {
|
||||
const s = seeds_val.object;
|
||||
if (s.get("numpy_seed")) |v| colors.printInfo("Numpy seed: {d}\n", .{v.integer});
|
||||
if (s.get("torch_seed")) |v| colors.printInfo("Torch seed: {d}\n", .{v.integer});
|
||||
if (s.get("tensorflow_seed")) |v| colors.printInfo("TensorFlow seed: {d}\n", .{v.integer});
|
||||
if (s.get("random_seed")) |v| colors.printInfo("Random seed: {d}\n", .{v.integer});
|
||||
}
|
||||
}
|
||||
}
|
||||
colors.printInfo("\n", .{});
|
||||
}
|
||||
} else if (packet.success_message) |msg| {
|
||||
if (options.json) {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.show\",\"data\":{{\"message\":\"{s}\"}}}}\n",
|
||||
.{msg},
|
||||
);
|
||||
} else {
|
||||
colors.printSuccess("{s}\n", .{msg});
|
||||
}
|
||||
}
|
||||
},
|
||||
.error_packet => {
|
||||
const code_int: u8 = if (packet.error_code) |c| @intFromEnum(c) else 0;
|
||||
const default_msg = if (packet.error_code) |c| protocol.ResponsePacket.getErrorMessage(c) else "Server error";
|
||||
const err_msg = packet.error_message orelse default_msg;
|
||||
const details = packet.error_details orelse "";
|
||||
if (options.json) {
|
||||
std.debug.print(
|
||||
"{{\"success\":false,\"command\":\"experiment.show\",\"error\":{s},\"error_code\":{d},\"error_details\":{s}}}\n",
|
||||
.{ err_msg, code_int, details },
|
||||
);
|
||||
} else {
|
||||
colors.printError("Error: {s}\n", .{err_msg});
|
||||
if (details.len > 0) {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
}
|
||||
}
|
||||
},
|
||||
else => {
|
||||
if (options.json) {
|
||||
jsonError("experiment.show", "Unexpected response type");
|
||||
} else {
|
||||
colors.printError("Unexpected response type\n", .{});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn executeList(allocator: std.mem.Allocator, options: *const ExperimentOptions) !void {
|
||||
const entries = history.loadEntries(allocator) catch |err| {
|
||||
if (options.json) {
|
||||
const details = try std.fmt.allocPrint(allocator, "{}", .{err});
|
||||
defer allocator.free(details);
|
||||
jsonErrorWithDetails("experiment.list", "Failed to read experiment history", details);
|
||||
if (json) {
|
||||
std.debug.print("{{\"success\":true,\"experiment_id\":\"{s}\",\"name\":\"{s}\"}}\n", .{ exp_id, name.? });
|
||||
} else {
|
||||
colors.printError("Failed to read experiment history: {}\n", .{err});
|
||||
colors.printSuccess("✓ Created experiment: {s} ({s})\n", .{ name.?, exp_id[0..8] });
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer history.freeEntries(allocator, entries);
|
||||
|
||||
if (entries.len == 0) {
|
||||
if (options.json) {
|
||||
std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[],\"total\":0,\"message\":\"No experiments recorded yet. Use `ml queue` to submit one.\"}}}}\n", .{});
|
||||
} else {
|
||||
colors.printWarning("No experiments recorded yet. Use `ml queue` to submit one.\n", .{});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (options.json) {
|
||||
std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[", .{});
|
||||
var idx: usize = 0;
|
||||
while (idx < entries.len) : (idx += 1) {
|
||||
const entry = entries[entries.len - idx - 1];
|
||||
if (idx > 0) {
|
||||
std.debug.print(",", .{});
|
||||
}
|
||||
std.debug.print(
|
||||
"{{\"alias\":\"{s}\",\"commit_id\":\"{s}\",\"queued_at\":{d}}}",
|
||||
.{
|
||||
entry.job_name, entry.commit_id,
|
||||
entry.queued_at,
|
||||
},
|
||||
);
|
||||
}
|
||||
std.debug.print("],\"total\":{d}", .{entries.len});
|
||||
std.debug.print("}}}}\n", .{});
|
||||
} else {
|
||||
colors.printInfo("\nRecent Experiments (latest first):\n", .{});
|
||||
colors.printInfo("---------------------------------\n", .{});
|
||||
|
||||
const max_display = if (entries.len > 20) 20 else entries.len;
|
||||
var idx: usize = 0;
|
||||
while (idx < max_display) : (idx += 1) {
|
||||
const entry = entries[entries.len - idx - 1];
|
||||
std.debug.print("{d:2}) Alias: {s}\n", .{ idx + 1, entry.job_name });
|
||||
std.debug.print(" Commit: {s}\n", .{entry.commit_id});
|
||||
std.debug.print(" Queued: {d}\n\n", .{entry.queued_at});
|
||||
}
|
||||
|
||||
if (entries.len > max_display) {
|
||||
colors.printInfo("...and {d} more\n", .{entries.len - max_display});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn executeDelete(allocator: std.mem.Allocator, identifier: []const u8, options: *const ExperimentOptions) !void {
|
||||
const resolved = try resolveJobIdentifier(allocator, identifier);
|
||||
defer allocator.free(resolved);
|
||||
|
||||
if (options.json) {
|
||||
const Config = @import("../config.zig").Config;
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
// Server mode: send to server via WebSocket
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendCancelJob(resolved, api_key_hash);
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
try client.sendCreateExperiment(api_key_hash, name.?, description orelse "");
|
||||
|
||||
// Prefer parsing structured binary response packets if present.
|
||||
if (message.len > 0) {
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch null;
|
||||
if (packet) |p| {
|
||||
defer p.deinit(allocator);
|
||||
// Receive response
|
||||
const response = try client.receiveMessage(allocator);
|
||||
defer allocator.free(response);
|
||||
|
||||
switch (p.packet_type) {
|
||||
.success => {
|
||||
const msg = p.success_message orelse "";
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.delete\",\"data\":{{\"experiment\":\"{s}\",\"message\":\"{s}\"}}}}\n",
|
||||
.{ resolved, msg },
|
||||
);
|
||||
return;
|
||||
},
|
||||
.error_packet => {
|
||||
const code_int: u8 = if (p.error_code) |c| @intFromEnum(c) else 0;
|
||||
const default_msg = if (p.error_code) |c| protocol.ResponsePacket.getErrorMessage(c) else "Server error";
|
||||
const err_msg = p.error_message orelse default_msg;
|
||||
const details = p.error_details orelse "";
|
||||
std.debug.print("{{\"success\":false,\"command\":\"experiment.delete\",\"error\":\"{s}\",\"error_code\":{d},\"error_details\":\"{s}\",\"data\":{{\"experiment\":\"{s}\"}}}}\n", .{ err_msg, code_int, details, resolved });
|
||||
return error.CommandFailed;
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
// Parse response (expecting JSON with experiment_id)
|
||||
if (std.mem.indexOf(u8, response, "experiment_id") != null) {
|
||||
// Also update local config
|
||||
var mut_cfg = cfg;
|
||||
if (mut_cfg.experiment == null) {
|
||||
mut_cfg.experiment = config.ExperimentConfig{
|
||||
.name = "",
|
||||
.entrypoint = "",
|
||||
};
|
||||
}
|
||||
mut_cfg.experiment.?.name = try allocator.dupe(u8, name.?);
|
||||
try mut_cfg.save(allocator);
|
||||
|
||||
if (json) {
|
||||
std.debug.print("{{\"success\":true,\"name\":\"{s}\",\"source\":\"server\"}}\n", .{name.?});
|
||||
} else {
|
||||
colors.printSuccess("✓ Created experiment on server: {s}\n", .{name.?});
|
||||
}
|
||||
} else {
|
||||
colors.printError("Failed to create experiment on server: {s}\n", .{response});
|
||||
return error.ServerError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn listExperiments(allocator: std.mem.Allocator, _: []const []const u8, json: bool) !void {
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
// Local mode: list from SQLite
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
const sql = "SELECT experiment_id, name, description, created_at, status, synced FROM ml_experiments ORDER BY created_at DESC;";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
|
||||
var experiments = try std.ArrayList(ExperimentInfo).initCapacity(allocator, 16);
|
||||
defer {
|
||||
for (experiments.items) |*e| e.deinit(allocator);
|
||||
experiments.deinit(allocator);
|
||||
}
|
||||
|
||||
// Next: if server returned JSON, wrap it and attempt to infer success.
|
||||
if (message.len > 0 and message[0] == '{') {
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.delete\",\"data\":{{\"experiment\":\"{s}\",\"response\":{s}}}}}\n",
|
||||
.{ resolved, message },
|
||||
);
|
||||
return;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
while (try db.DB.step(stmt)) {
|
||||
try experiments.append(allocator, ExperimentInfo{
|
||||
.id = try allocator.dupe(u8, db.DB.columnText(stmt, 0)),
|
||||
.name = try allocator.dupe(u8, db.DB.columnText(stmt, 1)),
|
||||
.description = try allocator.dupe(u8, db.DB.columnText(stmt, 2)),
|
||||
.created_at = try allocator.dupe(u8, db.DB.columnText(stmt, 3)),
|
||||
.status = try allocator.dupe(u8, db.DB.columnText(stmt, 4)),
|
||||
.synced = db.DB.columnInt64(stmt, 5) != 0,
|
||||
});
|
||||
}
|
||||
|
||||
if (parsed.value == .object) {
|
||||
if (parsed.value.object.get("success")) |sval| {
|
||||
if (sval == .bool and !sval.bool) {
|
||||
const err_val = parsed.value.object.get("error");
|
||||
const err_msg = if (err_val != null and err_val.? == .string) err_val.?.string else "Failed to cancel experiment";
|
||||
std.debug.print(
|
||||
"{{\"success\":false,\"command\":\"experiment.delete\",\"error\":\"{s}\",\"data\":{{\"experiment\":\"{s}\",\"response\":{s}}}}}\n",
|
||||
.{ err_msg, resolved, message },
|
||||
);
|
||||
return error.CommandFailed;
|
||||
if (json) {
|
||||
std.debug.print("[", .{});
|
||||
for (experiments.items, 0..) |e, i| {
|
||||
if (i > 0) std.debug.print(",", .{});
|
||||
std.debug.print("{{\"id\":\"{s}\",\"name\":\"{s}\",\"status\":\"{s}\",\"description\":\"{s}\",\"synced\":{s}}}", .{ e.id, e.name, e.status, e.description, if (e.synced) "true" else "false" });
|
||||
}
|
||||
std.debug.print("]\n", .{});
|
||||
} else {
|
||||
if (experiments.items.len == 0) {
|
||||
colors.printInfo("No experiments found.\n", .{});
|
||||
} else {
|
||||
colors.printInfo("Experiments:\n", .{});
|
||||
for (experiments.items) |e| {
|
||||
const sync_indicator = if (e.synced) "✓" else "↑";
|
||||
std.debug.print(" {s} {s} {s} ({s})\n", .{ sync_indicator, e.id[0..8], e.name, e.status });
|
||||
if (e.description.len > 0) {
|
||||
std.debug.print(" {s}\n", .{e.description});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.delete\",\"data\":{{\"experiment\":\"{s}\",\"response\":{s}}}}}\n",
|
||||
.{ resolved, message },
|
||||
);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
// Server mode: query server via WebSocket
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
// Fallback: plain string message.
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.delete\",\"data\":{{\"experiment\":\"{s}\",\"message\":\"{s}\"}}}}\n",
|
||||
.{ resolved, message },
|
||||
);
|
||||
return;
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendListExperiments(api_key_hash);
|
||||
|
||||
// Receive response
|
||||
const response = try client.receiveMessage(allocator);
|
||||
defer allocator.free(response);
|
||||
|
||||
// For now, just display raw response
|
||||
if (json) {
|
||||
std.debug.print("{s}\n", .{response});
|
||||
} else {
|
||||
colors.printInfo("Experiments from server:\n", .{});
|
||||
std.debug.print("{s}\n", .{response});
|
||||
}
|
||||
}
|
||||
|
||||
// Build cancel args with JSON flag if needed
|
||||
var cancel_args = std.ArrayList([]const u8).initCapacity(allocator, 5) catch |err| {
|
||||
return err;
|
||||
};
|
||||
defer cancel_args.deinit(allocator);
|
||||
|
||||
try cancel_args.append(allocator, resolved);
|
||||
|
||||
cancel_cmd.run(allocator, cancel_args.items) catch |err| {
|
||||
colors.printError("Failed to cancel experiment '{s}': {}\n", .{ resolved, err });
|
||||
return err;
|
||||
};
|
||||
}
|
||||
|
||||
fn resolveCommitIdentifier(allocator: std.mem.Allocator, identifier: []const u8) ![]const u8 {
|
||||
const entries = history.loadEntries(allocator) catch {
|
||||
if (identifier.len != 40) return error.InvalidCommitId;
|
||||
const commit_bytes = try crypto.decodeHex(allocator, identifier);
|
||||
if (commit_bytes.len != 20) {
|
||||
allocator.free(commit_bytes);
|
||||
return error.InvalidCommitId;
|
||||
}
|
||||
return commit_bytes;
|
||||
};
|
||||
defer history.freeEntries(allocator, entries);
|
||||
|
||||
var commit_hex: []const u8 = identifier;
|
||||
for (entries) |entry| {
|
||||
if (std.mem.eql(u8, identifier, entry.job_name)) {
|
||||
commit_hex = entry.commit_id;
|
||||
break;
|
||||
}
|
||||
fn showExperiment(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
if (args.len == 0) {
|
||||
core.output.errorMsg("experiment", "experiment_id required");
|
||||
return error.MissingArgument;
|
||||
}
|
||||
|
||||
if (commit_hex.len != 40) return error.InvalidCommitId;
|
||||
const commit_bytes = try crypto.decodeHex(allocator, commit_hex);
|
||||
if (commit_bytes.len != 20) {
|
||||
allocator.free(commit_bytes);
|
||||
return error.InvalidCommitId;
|
||||
const exp_id = args[0];
|
||||
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
// Local mode: show from SQLite
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
// Get experiment details
|
||||
const exp_sql = "SELECT experiment_id, name, description, created_at, status, synced FROM ml_experiments WHERE experiment_id = ?;";
|
||||
const exp_stmt = try database.prepare(exp_sql);
|
||||
defer db.DB.finalize(exp_stmt);
|
||||
try db.DB.bindText(exp_stmt, 1, exp_id);
|
||||
|
||||
if (!try db.DB.step(exp_stmt)) {
|
||||
const msg = try std.fmt.allocPrint(allocator, "Experiment not found: {s}", .{exp_id});
|
||||
defer allocator.free(msg);
|
||||
core.output.errorMsg("experiment", msg);
|
||||
return error.NotFound;
|
||||
}
|
||||
|
||||
const name = db.DB.columnText(exp_stmt, 1);
|
||||
const description = db.DB.columnText(exp_stmt, 2);
|
||||
const created_at = db.DB.columnText(exp_stmt, 3);
|
||||
const status = db.DB.columnText(exp_stmt, 4);
|
||||
const synced = db.DB.columnInt64(exp_stmt, 5) != 0;
|
||||
|
||||
// Get run count and last run date
|
||||
const runs_sql =
|
||||
"SELECT COUNT(*), MAX(start_time) FROM ml_runs WHERE experiment_id = ?;";
|
||||
const runs_stmt = try database.prepare(runs_sql);
|
||||
defer db.DB.finalize(runs_stmt);
|
||||
try db.DB.bindText(runs_stmt, 1, exp_id);
|
||||
|
||||
var run_count: i64 = 0;
|
||||
var last_run: ?[]const u8 = null;
|
||||
if (try db.DB.step(runs_stmt)) {
|
||||
run_count = db.DB.columnInt64(runs_stmt, 0);
|
||||
if (db.DB.columnText(runs_stmt, 1).len > 0) {
|
||||
last_run = try allocator.dupe(u8, db.DB.columnText(runs_stmt, 1));
|
||||
}
|
||||
}
|
||||
defer if (last_run) |lr| allocator.free(lr);
|
||||
|
||||
if (json) {
|
||||
std.debug.print("{{\"experiment_id\":\"{s}\",\"name\":\"{s}\",\"description\":\"{s}\",\"status\":\"{s}\",\"created_at\":\"{s}\",\"synced\":{s},\"run_count\":{d},\"last_run\":\"{s}\"}}\n", .{
|
||||
exp_id, name, description, status, created_at,
|
||||
if (synced) "true" else "false", run_count, last_run orelse "null",
|
||||
});
|
||||
} else {
|
||||
colors.printInfo("Experiment: {s}\n", .{name});
|
||||
std.debug.print(" ID: {s}\n", .{exp_id});
|
||||
std.debug.print(" Status: {s}\n", .{status});
|
||||
if (description.len > 0) {
|
||||
std.debug.print(" Description: {s}\n", .{description});
|
||||
}
|
||||
std.debug.print(" Created: {s}\n", .{created_at});
|
||||
std.debug.print(" Synced: {s}\n", .{if (synced) "✓" else "↑ pending"});
|
||||
std.debug.print(" Runs: {d}\n", .{run_count});
|
||||
if (last_run) |lr| {
|
||||
std.debug.print(" Last run: {s}\n", .{lr});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Server mode: query server via WebSocket
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendGetExperimentByID(api_key_hash, exp_id);
|
||||
|
||||
// Receive response
|
||||
const response = try client.receiveMessage(allocator);
|
||||
defer allocator.free(response);
|
||||
|
||||
if (json) {
|
||||
std.debug.print("{s}\n", .{response});
|
||||
} else {
|
||||
colors.printInfo("Experiment details from server:\n", .{});
|
||||
std.debug.print("{s}\n", .{response});
|
||||
}
|
||||
}
|
||||
return commit_bytes;
|
||||
}
|
||||
|
||||
fn resolveJobIdentifier(allocator: std.mem.Allocator, identifier: []const u8) ![]const u8 {
|
||||
const entries = history.loadEntries(allocator) catch {
|
||||
return allocator.dupe(u8, identifier);
|
||||
};
|
||||
defer history.freeEntries(allocator, entries);
|
||||
|
||||
for (entries) |entry| {
|
||||
if (std.mem.eql(u8, identifier, entry.job_name) or
|
||||
std.mem.eql(u8, identifier, entry.commit_id) or
|
||||
(identifier.len <= entry.commit_id.len and
|
||||
std.mem.eql(u8, entry.commit_id[0..identifier.len], identifier)))
|
||||
{
|
||||
return allocator.dupe(u8, entry.job_name);
|
||||
}
|
||||
}
|
||||
|
||||
return allocator.dupe(u8, identifier);
|
||||
fn generateExperimentID(allocator: std.mem.Allocator) ![]const u8 {
|
||||
return try uuid.generateV4(allocator);
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
std.debug.print("Usage: ml experiment <subcommand> [options]\n\n", .{});
|
||||
std.debug.print("Subcommands:\n", .{});
|
||||
std.debug.print(" create --name <name> [--description <desc>] Create new experiment\n", .{});
|
||||
std.debug.print(" list List experiments\n", .{});
|
||||
std.debug.print(" show <experiment_id> Show experiment details\n", .{});
|
||||
std.debug.print("\nOptions:\n", .{});
|
||||
std.debug.print(" --name <string> Experiment name (required for create)\n", .{});
|
||||
std.debug.print(" --description <string> Experiment description\n", .{});
|
||||
std.debug.print(" --help, -h Show this help\n", .{});
|
||||
std.debug.print(" --json Output structured JSON\n\n", .{});
|
||||
std.debug.print("Examples:\n", .{});
|
||||
std.debug.print(" ml experiment create --name \"baseline-cnn\"\n", .{});
|
||||
std.debug.print(" ml experiment list\n", .{});
|
||||
}
|
||||
|
|
|
|||
348
cli/src/commands/export_cmd.zig
Normal file
348
cli/src/commands/export_cmd.zig
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub const ExportOptions = struct {
|
||||
anonymize: bool = false,
|
||||
anonymize_level: []const u8 = "metadata-only", // metadata-only, full
|
||||
bundle: ?[]const u8 = null,
|
||||
base_override: ?[]const u8 = null,
|
||||
json: bool = false,
|
||||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
const target = argv[0];
|
||||
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var anonymize = false;
|
||||
var anonymize_level: []const u8 = "metadata-only";
|
||||
var bundle: ?[]const u8 = null;
|
||||
var base_override: ?[]const u8 = null;
|
||||
|
||||
var i: usize = 1;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const arg = argv[i];
|
||||
if (std.mem.eql(u8, arg, "--anonymize")) {
|
||||
anonymize = true;
|
||||
} else if (std.mem.eql(u8, arg, "--anonymize-level") and i + 1 < argv.len) {
|
||||
anonymize_level = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--bundle") and i + 1 < argv.len) {
|
||||
bundle = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--base") and i + 1 < argv.len) {
|
||||
base_override = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
return printUsage();
|
||||
} else {
|
||||
core.output.errorMsg("export", "Unknown option");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
// Validate anonymize level
|
||||
if (!std.mem.eql(u8, anonymize_level, "metadata-only") and
|
||||
!std.mem.eql(u8, anonymize_level, "full"))
|
||||
{
|
||||
core.output.errorMsg("export", "Invalid anonymize level");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (flags.json) {
|
||||
var stdout_writer = io.stdoutWriter();
|
||||
try stdout_writer.print("{{\"success\":true,\"anonymize_level\":\"{s}\"}}\n", .{anonymize_level});
|
||||
} else {
|
||||
colors.printInfo("Anonymization level: {s}\n", .{anonymize_level});
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = base_override orelse cfg.worker_base;
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
// Read the manifest
|
||||
const manifest_content = manifest.readFileAlloc(allocator, manifest_path) catch |err| {
|
||||
colors.printError("Failed to read manifest: {}\n", .{err});
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_content);
|
||||
|
||||
// Parse the manifest
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, manifest_content, .{}) catch |err| {
|
||||
colors.printError("Failed to parse manifest: {}\n", .{err});
|
||||
return err;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
// Anonymize if requested
|
||||
var final_content: []u8 = undefined;
|
||||
var final_content_owned = false;
|
||||
|
||||
if (anonymize) {
|
||||
final_content = try anonymizeManifest(allocator, parsed.value, anonymize_level);
|
||||
final_content_owned = true;
|
||||
} else {
|
||||
final_content = manifest_content;
|
||||
}
|
||||
defer if (final_content_owned) allocator.free(final_content);
|
||||
|
||||
// Output or bundle
|
||||
if (bundle) |bundle_path| {
|
||||
// Create a simple tar-like bundle (just the manifest for now)
|
||||
// In production, this would include code, configs, etc.
|
||||
var bundle_file = try std.fs.cwd().createFile(bundle_path, .{});
|
||||
defer bundle_file.close();
|
||||
|
||||
try bundle_file.writeAll(final_content);
|
||||
|
||||
if (flags.json) {
|
||||
var stdout_writer = io.stdoutWriter();
|
||||
try stdout_writer.print("{{\"success\":true,\"bundle\":\"{s}\",\"anonymized\":{}}}\n", .{
|
||||
bundle_path,
|
||||
anonymize,
|
||||
});
|
||||
} else {
|
||||
colors.printSuccess("✓ Exported to {s}\n", .{bundle_path});
|
||||
if (anonymize) {
|
||||
colors.printInfo(" Anonymization level: {s}\n", .{anonymize_level});
|
||||
colors.printInfo(" Paths redacted, IPs removed, usernames anonymized\n", .{});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Output to stdout
|
||||
var stdout_writer = io.stdoutWriter();
|
||||
try stdout_writer.print("{s}\n", .{final_content});
|
||||
}
|
||||
}
|
||||
|
||||
/// Produce an anonymized copy of a run manifest, serialized as JSON.
///
/// The caller's `root` is never mutated: it is cloned by serializing and
/// re-parsing. `level` selects how aggressive the scrubbing is:
///   - any level: dataset_path rewritten, hostname/ip/username redacted
///   - "full":    code/output/checkpoint paths rewritten, logs and
///                annotations removed entirely
///
/// Returns a caller-owned JSON string allocated with `allocator`.
fn anonymizeManifest(
    allocator: std.mem.Allocator,
    root: std.json.Value,
    level: []const u8,
) ![]u8 {
    // Clone the value by stringifying and re-parsing so we can modify it.
    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(allocator);
    try writeJSONValue(buf.writer(allocator), root);
    const json_str = try buf.toOwnedSlice(allocator);
    defer allocator.free(json_str);
    var parsed_clone = try std.json.parseFromSlice(std.json.Value, allocator, json_str, .{});
    defer parsed_clone.deinit();
    var cloned = parsed_clone.value;

    // BUGFIX: replacement strings used to be freed by a scoped `defer`
    // while still referenced from the JSON tree, which was serialized
    // afterwards (use-after-free). Allocating them from the clone's arena
    // ties their lifetime to `parsed_clone` instead.
    const arena = parsed_clone.arena.allocator();

    if (cloned != .object) {
        // For non-objects, just re-serialize and return.
        var out_buf = std.ArrayList(u8).empty;
        defer out_buf.deinit(allocator);
        try writeJSONValue(out_buf.writer(allocator), cloned);
        return out_buf.toOwnedSlice(allocator);
    }

    const obj = &cloned.object;

    // Anonymize metadata fields. Mutate through getPtr so updates cannot be
    // lost to a by-value copy of the hash map header when it reallocates.
    if (obj.getPtr("metadata")) |meta| {
        if (meta.* == .object) {
            const meta_obj = &meta.object;

            // Path anonymization: /nas/private/user/data → /datasets/data
            if (meta_obj.get("dataset_path")) |dp| {
                if (dp == .string) {
                    const anon_path = try anonymizePath(arena, dp.string);
                    try meta_obj.put("dataset_path", std.json.Value{ .string = anon_path });
                }
            }

            // Anonymize other paths if full level.
            if (std.mem.eql(u8, level, "full")) {
                const path_fields = [_][]const u8{ "code_path", "output_path", "checkpoint_path" };
                for (path_fields) |field| {
                    if (meta_obj.get(field)) |p| {
                        if (p == .string) {
                            const anon_path = try anonymizePath(arena, p.string);
                            try meta_obj.put(field, std.json.Value{ .string = anon_path });
                        }
                    }
                }
            }
        }
    }

    // Anonymize system info.
    if (obj.getPtr("system")) |sys| {
        if (sys.* == .object) {
            const sys_obj = &sys.object;

            // Hostname: gpu-server-01.internal → worker-A
            if (sys_obj.get("hostname")) |h| {
                if (h == .string) {
                    try sys_obj.put("hostname", std.json.Value{ .string = "worker-A" });
                }
            }

            // IP addresses → [REDACTED]
            if (sys_obj.get("ip_address")) |_| {
                try sys_obj.put("ip_address", std.json.Value{ .string = "[REDACTED]" });
            }

            // Username: user@lab.edu → [REDACTED]
            if (sys_obj.get("username")) |_| {
                try sys_obj.put("username", std.json.Value{ .string = "[REDACTED]" });
            }
        }
    }

    // Logs and annotations may contain PII; drop them entirely at full level.
    if (std.mem.eql(u8, level, "full")) {
        _ = obj.swapRemove("logs");
        _ = obj.swapRemove("log_path");
        _ = obj.swapRemove("annotations");
    }

    // Serialize back to JSON.
    var out_buf = std.ArrayList(u8).empty;
    defer out_buf.deinit(allocator);
    try writeJSONValue(out_buf.writer(allocator), cloned);
    return out_buf.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Re-root a filesystem path under a generic directory, keeping only the
/// final component, e.g.:
///   /home/user/project/data      → /datasets/data
///   /nas/private/lab/checkpoints → /models/checkpoints
/// The generic root is picked from keywords appearing anywhere in the path.
/// A path with no '/' is returned as an owned copy, unchanged.
fn anonymizePath(allocator: std.mem.Allocator, path: []const u8) ![]const u8 {
    const sep_idx = std.mem.lastIndexOfScalar(u8, path, '/') orelse
        return allocator.dupe(u8, path);

    const leaf = path[sep_idx + 1 ..];

    const contains = struct {
        fn has(haystack: []const u8, needle: []const u8) bool {
            return std.mem.indexOf(u8, haystack, needle) != null;
        }
    }.has;

    // Choose the generic root from context keywords, first match wins.
    var root: []const u8 = "/workspace";
    if (contains(path, "data")) {
        root = "/datasets";
    } else if (contains(path, "model") or contains(path, "checkpoint")) {
        root = "/models";
    } else if (contains(path, "code") or contains(path, "src")) {
        root = "/code";
    }

    return std.fs.path.join(allocator, &[_][]const u8{ root, leaf });
}
|
||||
|
||||
/// Print usage/help text for the `ml export` subcommand, including the
/// anonymization rules applied by anonymizeManifest/anonymizePath.
fn printUsage() !void {
    colors.printInfo("Usage: ml export <run-id|path> [options]\n", .{});
    colors.printInfo("\nExport experiment for sharing or archiving:\n", .{});
    colors.printInfo(" --bundle <path> Create tarball at path\n", .{});
    colors.printInfo(" --anonymize Enable anonymization\n", .{});
    colors.printInfo(" --anonymize-level <lvl> 'metadata-only' or 'full'\n", .{});
    colors.printInfo(" --base <path> Base path to find run\n", .{});
    colors.printInfo(" --json Output JSON response\n", .{});
    colors.printInfo("\nAnonymization rules:\n", .{});
    colors.printInfo(" - Paths: /nas/private/... → /datasets/...\n", .{});
    colors.printInfo(" - Hostnames: gpu-server-01 → worker-A\n", .{});
    colors.printInfo(" - IPs: 192.168.1.100 → [REDACTED]\n", .{});
    colors.printInfo(" - Usernames: user@lab.edu → [REDACTED]\n", .{});
    colors.printInfo(" - Full level: Also removes logs and annotations\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo(" ml export run_abc --bundle run_abc.tar.gz\n", .{});
    colors.printInfo(" ml export run_abc --bundle run_abc.tar.gz --anonymize\n", .{});
    colors.printInfo(" ml export run_abc --anonymize --anonymize-level full\n", .{});
}
|
||||
|
||||
/// Serialize a std.json.Value recursively as compact JSON onto `writer`.
///
/// BUGFIX: object keys are now escaped via writeJSONString. Previously keys
/// were written raw with "{s}", so a key containing a quote, backslash, or
/// control character produced invalid JSON output.
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
    switch (v) {
        .null => try writer.writeAll("null"),
        .bool => |b| try writer.print("{}", .{b}),
        .integer => |i| try writer.print("{d}", .{i}),
        .float => |f| try writer.print("{d}", .{f}),
        .string => |s| try writeJSONString(writer, s),
        .array => |arr| {
            try writer.writeAll("[");
            for (arr.items, 0..) |item, idx| {
                if (idx > 0) try writer.writeAll(",");
                try writeJSONValue(writer, item);
            }
            try writer.writeAll("]");
        },
        .object => |obj| {
            try writer.writeAll("{");
            var first = true;
            var it = obj.iterator();
            while (it.next()) |entry| {
                if (!first) try writer.writeAll(",");
                first = false;
                // Keys need the same escaping as string values.
                try writeJSONString(writer, entry.key_ptr.*);
                try writer.writeAll(":");
                try writeJSONValue(writer, entry.value_ptr.*);
            }
            try writer.writeAll("}");
        },
        // number_string preserves the original numeric token verbatim.
        .number_string => |s| try writer.print("{s}", .{s}),
    }
}
|
||||
|
||||
/// Emit `s` as a double-quoted JSON string: quotes and backslashes are
/// backslash-escaped, \n/\r/\t use their short forms, and every other
/// control byte (< 0x20) is written as \u00XX.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    try writer.writeAll("\"");
    for (s) |ch| {
        const short_escape: ?[]const u8 = switch (ch) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            else => null,
        };
        if (short_escape) |esc| {
            try writer.writeAll(esc);
        } else if (ch < 0x20) {
            // Remaining control characters get the 4-digit unicode escape.
            try writer.print("\\u{x:0>4}", .{ch});
        } else {
            try writer.writeAll(&[_]u8{ch});
        }
    }
    try writer.writeAll("\"");
}
|
||||
|
||||
/// Map a nibble value (0–15) to its lowercase hexadecimal character.
fn hexDigit(v: u8) u8 {
    return switch (v) {
        0...9 => '0' + v,
        else => 'a' + (v - 10),
    };
}
|
||||
507
cli/src/commands/find.zig
Normal file
507
cli/src/commands/find.zig
Normal file
|
|
@ -0,0 +1,507 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
/// Options for the `ml find` subcommand: output format plus server-side
/// search filters. Unset (null) filters are omitted from the request JSON.
pub const FindOptions = struct {
    /// Output raw JSON instead of the human-readable table.
    json: bool = false,
    /// Output CSV rows (json takes precedence when both are set).
    csv: bool = false,
    /// Maximum number of results to request.
    limit: usize = 20,
    /// Filter: experiments carrying this tag.
    tag: ?[]const u8 = null,
    /// Filter: outcome value (e.g. "validates", "refutes", "partial").
    outcome: ?[]const u8 = null,
    /// Filter: dataset name.
    dataset: ?[]const u8 = null,
    /// Filter: experiment group name.
    experiment_group: ?[]const u8 = null,
    /// Filter: author identifier (e.g. an email) — format set by the server.
    author: ?[]const u8 = null,
    /// Filter: only runs after this date string (e.g. "2024-01-01").
    after: ?[]const u8 = null,
    /// Filter: only runs before this date string.
    before: ?[]const u8 = null,
    /// Free-text search query.
    query: ?[]const u8 = null,
};
|
||||
|
||||
/// Entry point for `ml find`: parse CLI filters, send a search request over
/// the WebSocket connection, and render the results as JSON, CSV, or a
/// human-readable table.
///
/// Fixes in this revision:
///  - `--experiment-group` is now accepted as an alias of `--group`, matching
///    the flag documented in printUsage (backward compatible).
///  - The CSV/human output paths receive the full `search_options`; they
///    previously got a fresh FindOptions with only json/csv set, which
///    dropped `query` and disabled the hypothesis display in
///    outputHumanResults.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) {
        return printUsage();
    }

    if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
        return printUsage();
    }

    var flags = core.flags.CommonFlags{};
    var limit: usize = 20;
    var csv: bool = false;
    var tag: ?[]const u8 = null;
    var outcome: ?[]const u8 = null;
    var dataset: ?[]const u8 = null;
    var experiment_group: ?[]const u8 = null;
    var author: ?[]const u8 = null;
    var after: ?[]const u8 = null;
    var before: ?[]const u8 = null;
    var query_str: ?[]const u8 = null;

    // First argument might be a query string or a flag.
    var arg_idx: usize = 0;
    if (!std.mem.startsWith(u8, argv[0], "--")) {
        query_str = argv[0];
        arg_idx = 1;
    }

    var i: usize = arg_idx;
    while (i < argv.len) : (i += 1) {
        const arg = argv[i];
        if (std.mem.eql(u8, arg, "--json")) {
            flags.json = true;
        } else if (std.mem.eql(u8, arg, "--csv")) {
            csv = true;
        } else if (std.mem.eql(u8, arg, "--limit") and i + 1 < argv.len) {
            limit = try std.fmt.parseInt(usize, argv[i + 1], 10);
            i += 1;
        } else if (std.mem.eql(u8, arg, "--tag") and i + 1 < argv.len) {
            tag = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--outcome") and i + 1 < argv.len) {
            outcome = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--dataset") and i + 1 < argv.len) {
            dataset = argv[i + 1];
            i += 1;
        } else if ((std.mem.eql(u8, arg, "--group") or
            std.mem.eql(u8, arg, "--experiment-group")) and i + 1 < argv.len)
        {
            // --experiment-group is the documented spelling; --group is kept
            // for backward compatibility.
            experiment_group = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--author") and i + 1 < argv.len) {
            author = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--after") and i + 1 < argv.len) {
            after = argv[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, arg, "--before") and i + 1 < argv.len) {
            before = argv[i + 1];
            i += 1;
        } else {
            core.output.errorMsg("find", "Unknown option");
            return error.InvalidArgs;
        }
    }

    const cfg = try Config.load(allocator);
    defer {
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);
    colors.printInfo("Searching experiments...\n", .{});

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    // Build search options struct for JSON builder.
    const search_options = FindOptions{
        .json = flags.json,
        .csv = csv,
        .limit = limit,
        .tag = tag,
        .outcome = outcome,
        .dataset = dataset,
        .experiment_group = experiment_group,
        .author = author,
        .after = after,
        .before = before,
        .query = query_str,
    };

    const search_json = try buildSearchJson(allocator, &search_options);
    defer allocator.free(search_json);

    // Send search request - we'll use the dataset search opcode as a placeholder.
    // In production, this would have a dedicated search endpoint.
    try client.sendDatasetSearch(search_json, api_key_hash);

    const msg = try client.receiveMessage(allocator);
    defer allocator.free(msg);

    // Parse response.
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, msg, .{}) catch {
        if (flags.json) {
            var out = io.stdoutWriter();
            try out.print("{{\"error\":\"invalid_response\"}}\n", .{});
        } else {
            colors.printError("Failed to parse search results\n", .{});
        }
        return error.InvalidResponse;
    };
    defer parsed.deinit();

    const root = parsed.value;

    if (flags.json) {
        try io.stdoutWriteJson(root);
    } else if (csv) {
        // Pass the full options so output code sees all filters, not just
        // the output-format flags.
        try outputCsvResults(allocator, root, &search_options);
    } else {
        try outputHumanResults(root, &search_options);
    }
}
|
||||
|
||||
/// Build the JSON search-request body from the filters in `options`.
/// Only non-null string filters are emitted, in a fixed order; `limit` is
/// always present last. Returns a caller-owned JSON string.
fn buildSearchJson(allocator: std.mem.Allocator, options: *const FindOptions) ![]u8 {
    var out = std.ArrayList(u8).empty;
    defer out.deinit(allocator);

    const writer = out.writer(allocator);
    try writer.writeAll("{");

    // Table-driven emission: key name paired with the optional value.
    const Field = struct { key: []const u8, val: ?[]const u8 };
    const fields = [_]Field{
        .{ .key = "query", .val = options.query },
        .{ .key = "tag", .val = options.tag },
        .{ .key = "outcome", .val = options.outcome },
        .{ .key = "dataset", .val = options.dataset },
        .{ .key = "experiment_group", .val = options.experiment_group },
        .{ .key = "author", .val = options.author },
        .{ .key = "after", .val = options.after },
        .{ .key = "before", .val = options.before },
    };

    var wrote_any = false;
    for (fields) |field| {
        const value = field.val orelse continue;
        if (wrote_any) try writer.writeAll(",");
        wrote_any = true;
        try writer.print("\"{s}\":", .{field.key});
        try writeJSONString(writer, value);
    }

    if (wrote_any) try writer.writeAll(",");
    try writer.print("\"limit\":{d}", .{options.limit});

    try writer.writeAll("}");

    return out.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Render search results as a colorized, column-aligned table on stdout.
///
/// Accepts any of the envelope keys "results"/"experiments"/"runs"; rows are
/// colored by the run's `outcome` field. When a free-text query was given,
/// a truncated hypothesis line is printed beneath each matching row.
/// Server errors and malformed responses are reported and swallowed (no
/// error is returned for them).
fn outputHumanResults(root: std.json.Value, options: *const FindOptions) !void {
    if (root != .object) {
        colors.printError("Invalid response format\n", .{});
        return;
    }

    const obj = root.object;

    // Check for a server-reported error before touching results.
    if (obj.get("error")) |err| {
        if (err == .string) {
            colors.printError("Search error: {s}\n", .{err.string});
        }
        return;
    }

    // The server has used several envelope keys; accept any of them.
    const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
    if (results == null) {
        colors.printInfo("No results found\n", .{});
        return;
    }

    if (results.? != .array) {
        colors.printError("Invalid results format\n", .{});
        return;
    }

    const items = results.?.array.items;

    if (items.len == 0) {
        colors.printInfo("No experiments found matching your criteria\n", .{});
        return;
    }

    colors.printSuccess("Found {d} experiment(s)\n\n", .{items.len});

    // Print header row and separator.
    colors.printInfo("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
        "ID", "Job Name", "Outcome", "Status", "Group/Tags",
    });
    colors.printInfo("{s}\n", .{"────────────────────────────────────────────────────────────────────────────────"});

    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        // IDs are truncated to 8 chars for display; either "id" or "run_id"
        // may be present.
        const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
        const short_id = if (id.len > 8) id[0..8] else id;

        const job_name = jsonGetString(run_obj, "job_name") orelse "unnamed";
        const job_display = if (job_name.len > 18) job_name[0..18] else job_name;

        const outcome = jsonGetString(run_obj, "outcome") orelse "-";
        const status = jsonGetString(run_obj, "status") orelse "unknown";

        // Build a "group/tags" summary, truncating each part to fit the
        // fixed 30-byte scratch buffer; on bufPrint failure fall back to the
        // group alone.
        var summary_buf: [30]u8 = undefined;
        const summary = blk: {
            const group = jsonGetString(run_obj, "experiment_group");
            const tags = run_obj.get("tags");

            if (group) |g| {
                if (tags) |t| {
                    if (t == .string) {
                        break :blk std.fmt.bufPrint(&summary_buf, "{s}/{s}", .{ g[0..@min(g.len, 10)], t.string[0..@min(t.string.len, 10)] }) catch g[0..@min(g.len, 15)];
                    }
                }
                break :blk g[0..@min(g.len, 20)];
            }
            break :blk "-";
        };

        // Color code by outcome: green=validates, red=refutes,
        // yellow=partial/inconclusive, default otherwise.
        if (std.mem.eql(u8, outcome, "validates")) {
            colors.printSuccess("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        } else if (std.mem.eql(u8, outcome, "refutes")) {
            colors.printError("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        } else if (std.mem.eql(u8, outcome, "partial") or std.mem.eql(u8, outcome, "inconclusive")) {
            colors.printWarning("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        } else {
            colors.printInfo("{s:12} {s:20} {s:15} {s:10} {s}\n", .{
                short_id, job_display, outcome, status, summary,
            });
        }

        // Show the hypothesis (first 50 chars) when a free-text query was
        // given. NOTE(review): "..." is appended even when the hypothesis is
        // not truncated — confirm whether that is intentional.
        if (options.query) |_| {
            if (run_obj.get("narrative")) |narr| {
                if (narr == .object) {
                    if (narr.object.get("hypothesis")) |h| {
                        if (h == .string and h.string.len > 0) {
                            const hypo = h.string;
                            const display = if (hypo.len > 50) hypo[0..50] else hypo;
                            colors.printInfo(" ↳ {s}...\n", .{display});
                        }
                    }
                }
            }
        }
    }

    colors.printInfo("\nUse 'ml info <id>' for details, 'ml compare <a> <b>' to compare runs\n", .{});
}
|
||||
|
||||
/// Emit search results as CSV on stdout: header row plus one row per run.
/// Fields that may contain commas/quotes/newlines are escaped via escapeCsv.
/// Server errors are printed; malformed/empty result sets emit nothing.
///
/// BUGFIX: rows are now formatted with allocPrint. The previous fixed
/// 1024-byte stack buffer returned error.NoSpaceLeft — aborting the whole
/// dump — as soon as one row (e.g. a long hypothesis) exceeded it.
fn outputCsvResults(allocator: std.mem.Allocator, root: std.json.Value, options: *const FindOptions) !void {
    // Options are part of the call contract but currently unused here.
    _ = options;

    if (root != .object) {
        colors.printError("Invalid response format\n", .{});
        return;
    }

    const obj = root.object;

    // Check for a server-reported error.
    if (obj.get("error")) |err| {
        if (err == .string) {
            colors.printError("Search error: {s}\n", .{err.string});
        }
        return;
    }

    // Accept any of the envelope keys the server has used.
    const results = obj.get("results") orelse obj.get("experiments") orelse obj.get("runs");
    if (results == null) {
        return;
    }

    if (results.? != .array) {
        return;
    }

    const items = results.?.array.items;

    // CSV header.
    const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll("id,job_name,outcome,status,experiment_group,tags,hypothesis\n");

    for (items) |item| {
        if (item != .object) continue;
        const run_obj = item.object;

        const id = jsonGetString(run_obj, "id") orelse jsonGetString(run_obj, "run_id") orelse "unknown";
        const job_name = jsonGetString(run_obj, "job_name") orelse "";
        const outcome = jsonGetString(run_obj, "outcome") orelse "";
        const status = jsonGetString(run_obj, "status") orelse "";
        const group = jsonGetString(run_obj, "experiment_group") orelse "";

        // Tags: only a plain string value is emitted.
        var tags: []const u8 = "";
        if (run_obj.get("tags")) |t| {
            if (t == .string) tags = t.string;
        }

        // Hypothesis lives under narrative.hypothesis when present.
        var hypothesis: []const u8 = "";
        if (run_obj.get("narrative")) |narr| {
            if (narr == .object) {
                if (narr.object.get("hypothesis")) |h| {
                    if (h == .string) hypothesis = h.string;
                }
            }
        }

        // Escape fields that might contain commas or quotes.
        const safe_job = try escapeCsv(allocator, job_name);
        defer allocator.free(safe_job);
        const safe_group = try escapeCsv(allocator, group);
        defer allocator.free(safe_group);
        const safe_tags = try escapeCsv(allocator, tags);
        defer allocator.free(safe_tags);
        const safe_hypo = try escapeCsv(allocator, hypothesis);
        defer allocator.free(safe_hypo);

        // Build the row on the heap so arbitrarily long fields cannot
        // overflow a fixed buffer.
        const line = try std.fmt.allocPrint(allocator, "{s},{s},{s},{s},{s},{s},{s}\n", .{
            id, safe_job, outcome, status, safe_group, safe_tags, safe_hypo,
        });
        defer allocator.free(line);
        try stdout_file.writeAll(line);
    }
}
|
||||
|
||||
/// Escape a CSV field per RFC 4180 conventions: if `s` contains a comma,
/// double quote, CR, or LF it is wrapped in double quotes with embedded
/// quotes doubled; otherwise it is returned unchanged.
/// Always returns a caller-owned copy.
fn escapeCsv(allocator: std.mem.Allocator, s: []const u8) ![]u8 {
    // Single-pass detection via indexOfAny instead of a manual scan loop.
    if (std.mem.indexOfAny(u8, s, ",\"\n\r") == null) {
        return allocator.dupe(u8, s);
    }

    // Escape: wrap in quotes and double existing quotes.
    // `try` replaces the previous no-op `catch |err| return err`.
    var buf = try std.ArrayList(u8).initCapacity(allocator, s.len + 2);
    defer buf.deinit(allocator);

    try buf.append(allocator, '"');
    for (s) |c| {
        if (c == '"') {
            try buf.appendSlice(allocator, "\"\"");
        } else {
            try buf.append(allocator, c);
        }
    }
    try buf.append(allocator, '"');

    return buf.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Look up `key` in `obj` and return its text only when the stored value is
/// a JSON string; a missing key or a non-string value yields null.
fn jsonGetString(obj: std.json.ObjectMap, key: []const u8) ?[]const u8 {
    const value = obj.get(key) orelse return null;
    return switch (value) {
        .string => |s| s,
        else => null,
    };
}
|
||||
|
||||
/// Write `s` to `writer` as a quoted JSON string. Quotes and backslashes
/// are escaped, \n/\r/\t use their short escapes, and any other control
/// byte (< 0x20) becomes a \u00XX sequence.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    try writer.writeAll("\"");
    for (s) |ch| {
        const short_escape: ?[]const u8 = switch (ch) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            else => null,
        };
        if (short_escape) |esc| {
            try writer.writeAll(esc);
        } else if (ch < 0x20) {
            // Remaining control characters get the 4-digit unicode escape.
            try writer.print("\\u{x:0>4}", .{ch});
        } else {
            try writer.writeAll(&[_]u8{ch});
        }
    }
    try writer.writeAll("\"");
}
|
||||
|
||||
/// Map a nibble value (0–15) to its lowercase hexadecimal character.
fn hexDigit(v: u8) u8 {
    return switch (v) {
        0...9 => '0' + v,
        else => 'a' + (v - 10),
    };
}
|
||||
|
||||
/// Print usage/help text for the `ml find` subcommand.
///
/// BUGFIX: the help previously documented `--experiment-group`, which the
/// argument parser in run() rejects; the accepted flag is `--group`.
fn printUsage() !void {
    colors.printInfo("Usage: ml find [query] [options]\n", .{});
    colors.printInfo("\nSearch experiments by:\n", .{});
    colors.printInfo(" Query (free text): ml find \"hypothesis: warmup\"\n", .{});
    colors.printInfo(" Tags: ml find --tag ablation\n", .{});
    colors.printInfo(" Outcome: ml find --outcome validates\n", .{});
    colors.printInfo(" Dataset: ml find --dataset imagenet\n", .{});
    colors.printInfo(" Experiment group: ml find --group lr-scaling\n", .{});
    colors.printInfo(" Author: ml find --author user@lab.edu\n", .{});
    colors.printInfo(" Time range: ml find --after 2024-01-01 --before 2024-03-01\n", .{});
    colors.printInfo("\nOptions:\n", .{});
    colors.printInfo(" --limit <n> Max results (default: 20)\n", .{});
    colors.printInfo(" --json Output as JSON\n", .{});
    colors.printInfo(" --csv Output as CSV\n", .{});
    colors.printInfo(" --help, -h Show this help\n", .{});
    colors.printInfo("\nExamples:\n", .{});
    colors.printInfo(" ml find --tag ablation --outcome validates\n", .{});
    colors.printInfo(" ml find --group batch-scaling --json\n", .{});
    colors.printInfo(" ml find \"learning rate\" --after 2024-01-01\n", .{});
}
|
||||
|
|
@ -4,6 +4,7 @@ const Config = @import("../config.zig").Config;
|
|||
const io = @import("../utils/io.zig");
|
||||
const json = @import("../utils/json.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub const Options = struct {
|
||||
json: bool = false,
|
||||
|
|
@ -11,59 +12,49 @@ pub const Options = struct {
|
|||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
var opts = Options{};
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var base: ?[]const u8 = null;
|
||||
var target_path: ?[]const u8 = null;
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
opts.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--base")) {
|
||||
if (i + 1 >= args.len) {
|
||||
colors.printError("Missing value for --base\n", .{});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
opts.base = args[i + 1];
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--base") and i + 1 < args.len) {
|
||||
base = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.startsWith(u8, arg, "--help")) {
|
||||
try printUsage();
|
||||
return;
|
||||
return printUsage();
|
||||
} else if (std.mem.startsWith(u8, arg, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
try printUsage();
|
||||
core.output.errorMsg("info", "Unknown option");
|
||||
return error.InvalidArgs;
|
||||
} else {
|
||||
target_path = arg;
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
if (target_path == null) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
core.output.errorMsg("info", "No target path specified");
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target_path.?, opts.base) catch |err| {
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target_path.?, base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
|
||||
.{target_path.?},
|
||||
);
|
||||
core.output.errorMsgDetailed("info", "Manifest not found", "Provide a path or use --base <path>");
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const data = try manifest.readFileAlloc(allocator, manifest_path);
|
||||
defer allocator.free(data);
|
||||
defer {
|
||||
allocator.free(manifest_path);
|
||||
allocator.free(data);
|
||||
}
|
||||
|
||||
if (opts.json) {
|
||||
if (flags.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{data});
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -1,23 +1,109 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const db = @import("../db.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub fn run(_: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len > 0 and (std.mem.eql(u8, args[0], "--help") or std.mem.eql(u8, args[0], "-h"))) {
|
||||
printUsage();
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var remaining = try core.flags.parseCommon(allocator, args, &flags);
|
||||
defer remaining.deinit(allocator);
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
// Handle help flag early
|
||||
if (flags.help) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
// Parse CLI-specific overrides and flags
|
||||
const cli_tracking_uri = core.flags.parseKVFlag(remaining.items, "tracking-uri");
|
||||
const cli_artifact_path = core.flags.parseKVFlag(remaining.items, "artifact-path");
|
||||
const cli_sync_uri = core.flags.parseKVFlag(remaining.items, "sync-uri");
|
||||
const force_local = core.flags.parseBoolFlag(remaining.items, "local");
|
||||
|
||||
var cfg = try Config.loadWithOverrides(allocator, cli_tracking_uri, cli_artifact_path, cli_sync_uri);
|
||||
defer cfg.deinit(allocator);
|
||||
|
||||
// Print resolved config
|
||||
std.debug.print("Resolved config:\n", .{});
|
||||
std.debug.print(" tracking_uri = {s}", .{cfg.tracking_uri});
|
||||
|
||||
// Indicate if using default
|
||||
if (cli_tracking_uri == null and std.mem.eql(u8, cfg.tracking_uri, "sqlite://./fetch_ml.db")) {
|
||||
std.debug.print(" (default)\n", .{});
|
||||
} else {
|
||||
std.debug.print("\n", .{});
|
||||
}
|
||||
|
||||
std.debug.print(" artifact_path = {s}", .{cfg.artifact_path});
|
||||
if (cli_artifact_path == null and std.mem.eql(u8, cfg.artifact_path, "./experiments/")) {
|
||||
std.debug.print(" (default)\n", .{});
|
||||
} else {
|
||||
std.debug.print("\n", .{});
|
||||
}
|
||||
std.debug.print(" sync_uri = {s}\n", .{if (cfg.sync_uri.len > 0) cfg.sync_uri else "(not set)"});
|
||||
std.debug.print("\n", .{});
|
||||
|
||||
// Default path: create config only (no DB speculatively)
|
||||
if (!force_local) {
|
||||
std.debug.print("✓ Created .fetchml/config.toml\n", .{});
|
||||
std.debug.print(" Local tracking DB will be created automatically if server becomes unavailable.\n", .{});
|
||||
|
||||
if (cfg.sync_uri.len > 0) {
|
||||
std.debug.print(" Server: {s}:{d}\n", .{ cfg.worker_host, cfg.worker_port });
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std.debug.print("ML Experiment Manager - Configuration Setup\n\n", .{});
|
||||
std.debug.print("Please create ~/.ml/config.toml with the following format:\n\n", .{});
|
||||
std.debug.print("worker_host = \"worker.local\"\n", .{});
|
||||
std.debug.print("worker_user = \"mluser\"\n", .{});
|
||||
std.debug.print("worker_base = \"/data/ml-experiments\"\n", .{});
|
||||
std.debug.print("worker_port = 22\n", .{});
|
||||
std.debug.print("api_key = \"your-api-key\"\n", .{});
|
||||
std.debug.print("\n[OK] Configuration template shown above\n", .{});
|
||||
// --local path: create config + DB now
|
||||
std.debug.print("(local mode explicitly requested)\n\n", .{});
|
||||
|
||||
// Get DB path from tracking URI
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
// Check if DB already exists
|
||||
const db_exists = blk: {
|
||||
std.fs.accessAbsolute(db_path, .{}) catch |err| {
|
||||
if (err == error.FileNotFound) break :blk false;
|
||||
};
|
||||
break :blk true;
|
||||
};
|
||||
|
||||
if (db_exists) {
|
||||
std.debug.print("✓ Database already exists: {s}\n", .{db_path});
|
||||
} else {
|
||||
// Create parent directories if needed
|
||||
if (std.fs.path.dirname(db_path)) |dir| {
|
||||
std.fs.makeDirAbsolute(dir) catch |err| {
|
||||
if (err != error.PathAlreadyExists) {
|
||||
std.log.err("Failed to create directory {s}: {}", .{ dir, err });
|
||||
return error.MkdirFailed;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Initialize database (creates schema)
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
defer database.checkpointOnExit();
|
||||
|
||||
std.debug.print("✓ Created database: {s}\n", .{db_path});
|
||||
}
|
||||
|
||||
std.debug.print("✓ Created .fetchml/config.toml\n", .{});
|
||||
std.debug.print("✓ Schema applied (WAL mode enabled)\n", .{});
|
||||
std.debug.print(" fetch_ml.db-wal and fetch_ml.db-shm will appear during use — expected.\n", .{});
|
||||
std.debug.print(" The DB is just a file. Delete it freely — recreated automatically on next run.\n", .{});
|
||||
}
|
||||
|
||||
fn printUsage() void {
|
||||
std.debug.print("Usage: ml init\n\n", .{});
|
||||
std.debug.print("Shows a template for ~/.ml/config.toml\n", .{});
|
||||
std.debug.print("Usage: ml init [OPTIONS]\n\n", .{});
|
||||
std.debug.print("Initialize FetchML configuration\n\n", .{});
|
||||
std.debug.print("Options:\n", .{});
|
||||
std.debug.print(" --local Create local database now (default: config only)\n", .{});
|
||||
std.debug.print(" --tracking-uri URI SQLite database path (e.g., sqlite://./fetch_ml.db)\n", .{});
|
||||
std.debug.print(" --artifact-path PATH Artifacts directory (default: ./experiments/)\n", .{});
|
||||
std.debug.print(" --sync-uri URI Server to sync with (e.g., wss://ml.company.com/ws)\n", .{});
|
||||
std.debug.print(" -h, --help Show this help\n", .{});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ const ws = @import("../net/ws/client.zig");
|
|||
const protocol = @import("../net/protocol.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const core = @import("../core.zig");
|
||||
|
||||
const blocked_packages = [_][]const u8{ "requests", "urllib3", "httpx", "aiohttp", "socket", "telnetlib" };
|
||||
|
||||
|
|
@ -23,9 +24,10 @@ fn validatePackageName(name: []const u8) bool {
|
|||
return true;
|
||||
}
|
||||
|
||||
fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
_ = json;
|
||||
if (args.len < 1) {
|
||||
colors.printError("Usage: ml jupyter restore <name>\n", .{});
|
||||
core.output.errorMsg("jupyter.restore", "Usage: ml jupyter restore <name>");
|
||||
return;
|
||||
}
|
||||
const name = args[0];
|
||||
|
|
@ -48,10 +50,10 @@ fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void
|
|||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
colors.printInfo("Restoring workspace {s}...\n", .{name});
|
||||
core.output.info("Restoring workspace {s}...", .{name});
|
||||
|
||||
client.sendRestoreJupyter(name, api_key_hash) catch |err| {
|
||||
colors.printError("Failed to send restore command: {}\n", .{err});
|
||||
core.output.errorMsgDetailed("jupyter.restore", "Failed to send restore command", @errorName(err));
|
||||
return;
|
||||
};
|
||||
|
||||
|
|
@ -70,22 +72,17 @@ fn restoreJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void
|
|||
switch (packet.packet_type) {
|
||||
.success => {
|
||||
if (packet.success_message) |msg| {
|
||||
colors.printSuccess("{s}\n", .{msg});
|
||||
core.output.info("{s}", .{msg});
|
||||
} else {
|
||||
colors.printSuccess("Workspace restored.\n", .{});
|
||||
core.output.info("Workspace restored.", .{});
|
||||
}
|
||||
},
|
||||
.error_packet => {
|
||||
const error_msg = protocol.ResponsePacket.getErrorMessage(packet.error_code.?);
|
||||
colors.printError("Failed to restore workspace: {s}\n", .{error_msg});
|
||||
if (packet.error_details) |details| {
|
||||
colors.printError("Details: {s}\n", .{details});
|
||||
} else if (packet.error_message) |msg| {
|
||||
colors.printError("Details: {s}\n", .{msg});
|
||||
}
|
||||
core.output.errorMsgDetailed("jupyter.restore", error_msg, packet.error_details orelse packet.error_message orelse "");
|
||||
},
|
||||
else => {
|
||||
colors.printError("Unexpected response type\n", .{});
|
||||
core.output.errorMsg("jupyter.restore", "Unexpected response type");
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -139,50 +136,62 @@ pub fn defaultWorkspacePath(allocator: std.mem.Allocator, name: []const u8) ![]u
|
|||
}
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len < 1) {
|
||||
printUsagePackage();
|
||||
return;
|
||||
var flags = core.flags.CommonFlags{};
|
||||
|
||||
if (args.len == 0) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
// Global flags
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsagePackage();
|
||||
return;
|
||||
}
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
colors.printError("jupyter does not support --json\n", .{});
|
||||
printUsagePackage();
|
||||
return error.InvalidArgs;
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
}
|
||||
}
|
||||
|
||||
const action = args[0];
|
||||
const sub = args[0];
|
||||
|
||||
if (std.mem.eql(u8, action, "create")) {
|
||||
try createJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "start")) {
|
||||
try startJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "stop")) {
|
||||
try stopJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "status")) {
|
||||
try statusJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "list")) {
|
||||
try listServices(allocator);
|
||||
} else if (std.mem.eql(u8, action, "remove")) {
|
||||
try removeJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "restore")) {
|
||||
try restoreJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, action, "package")) {
|
||||
try packageCommands(args[1..]);
|
||||
if (std.mem.eql(u8, sub, "list")) {
|
||||
return listJupyter(allocator, args[1..], flags.json);
|
||||
} else if (std.mem.eql(u8, sub, "status")) {
|
||||
return statusJupyter(allocator, args[1..], flags.json);
|
||||
} else if (std.mem.eql(u8, sub, "launch")) {
|
||||
return launchJupyter(allocator, args[1..], flags.json);
|
||||
} else if (std.mem.eql(u8, sub, "terminate")) {
|
||||
return terminateJupyter(allocator, args[1..], flags.json);
|
||||
} else if (std.mem.eql(u8, sub, "save")) {
|
||||
return saveJupyter(allocator, args[1..], flags.json);
|
||||
} else if (std.mem.eql(u8, sub, "restore")) {
|
||||
return restoreJupyter(allocator, args[1..], flags.json);
|
||||
} else if (std.mem.eql(u8, sub, "install")) {
|
||||
return installJupyter(allocator, args[1..]);
|
||||
} else if (std.mem.eql(u8, sub, "uninstall")) {
|
||||
return uninstallJupyter(allocator, args[1..]);
|
||||
} else {
|
||||
colors.printError("Invalid action: {s}\n", .{action});
|
||||
core.output.errorMsg("jupyter", "Unknown subcommand");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
std.debug.print("Usage: ml jupyter <command> [args]\n", .{});
|
||||
std.debug.print("\nCommands:\n", .{});
|
||||
std.debug.print(" list List Jupyter services\n", .{});
|
||||
std.debug.print(" status Show Jupyter service status\n", .{});
|
||||
std.debug.print(" launch Launch a new Jupyter service\n", .{});
|
||||
std.debug.print(" terminate Terminate a Jupyter service\n", .{});
|
||||
std.debug.print(" save Save workspace\n", .{});
|
||||
std.debug.print(" restore Restore workspace\n", .{});
|
||||
std.debug.print(" install Install packages\n", .{});
|
||||
std.debug.print(" uninstall Uninstall packages\n", .{});
|
||||
}
|
||||
|
||||
fn printUsagePackage() void {
|
||||
colors.printError("Usage: ml jupyter package <action> [options]\n", .{});
|
||||
colors.printInfo("Actions:\n", .{});
|
||||
colors.printInfo(" list\n", .{});
|
||||
core.output.info("Actions:\n", .{});
|
||||
core.output.info("{s}", .{});
|
||||
colors.printInfo("Options:\n", .{});
|
||||
colors.printInfo(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
|
|
@ -505,8 +514,15 @@ fn removeJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
}
|
||||
|
||||
fn statusJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
_ = args; // Not used yet
|
||||
fn listJupyter(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
_ = args;
|
||||
_ = json;
|
||||
try listServices(allocator);
|
||||
}
|
||||
|
||||
fn statusJupyter(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
_ = args;
|
||||
_ = json;
|
||||
// Re-use listServices for now as status is part of list
|
||||
try listServices(allocator);
|
||||
}
|
||||
|
|
@ -850,3 +866,41 @@ fn packageCommands(args: []const []const u8) !void {
|
|||
colors.printError("Invalid package command: {s}\n", .{subcommand});
|
||||
}
|
||||
}
|
||||
|
||||
fn launchJupyter(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
_ = allocator;
|
||||
_ = args;
|
||||
_ = json;
|
||||
core.output.errorMsg("jupyter.launch", "Not implemented");
|
||||
return error.NotImplemented;
|
||||
}
|
||||
|
||||
fn terminateJupyter(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
_ = allocator;
|
||||
_ = args;
|
||||
_ = json;
|
||||
core.output.errorMsg("jupyter.terminate", "Not implemented");
|
||||
return error.NotImplemented;
|
||||
}
|
||||
|
||||
fn saveJupyter(allocator: std.mem.Allocator, args: []const []const u8, json: bool) !void {
|
||||
_ = allocator;
|
||||
_ = args;
|
||||
_ = json;
|
||||
core.output.errorMsg("jupyter.save", "Not implemented");
|
||||
return error.NotImplemented;
|
||||
}
|
||||
|
||||
fn installJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
_ = allocator;
|
||||
_ = args;
|
||||
core.output.errorMsg("jupyter.install", "Not implemented");
|
||||
return error.NotImplemented;
|
||||
}
|
||||
|
||||
fn uninstallJupyter(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
_ = allocator;
|
||||
_ = args;
|
||||
core.output.errorMsg("jupyter.uninstall", "Not implemented");
|
||||
return error.NotImplemented;
|
||||
}
|
||||
|
|
|
|||
192
cli/src/commands/log.zig
Normal file
192
cli/src/commands/log.zig
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
const std = @import("std");
|
||||
const config = @import("../config.zig");
|
||||
const core = @import("../core.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
|
||||
/// Logs command - fetch or stream run logs
|
||||
/// Usage:
|
||||
/// ml logs <run_id> # Fetch logs from local file or server
|
||||
/// ml logs <run_id> --follow # Stream logs from server
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var command_args = try core.flags.parseCommon(allocator, args, &flags);
|
||||
defer command_args.deinit(allocator);
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
if (flags.help) {
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
if (command_args.items.len < 1) {
|
||||
std.log.err("Usage: ml logs <run_id> [--follow]", .{});
|
||||
return error.MissingArgument;
|
||||
}
|
||||
|
||||
const target = command_args.items[0];
|
||||
const follow = core.flags.parseBoolFlag(command_args.items, "follow");
|
||||
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
// Detect mode
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
if (mode_result.warning) |w| {
|
||||
std.log.warn("{s}", .{w});
|
||||
}
|
||||
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
// Local mode: read from output.log file
|
||||
return try fetchLocalLogs(allocator, target, &cfg, flags.json);
|
||||
} else {
|
||||
// Online mode: fetch or stream from server
|
||||
if (follow) {
|
||||
return try streamServerLogs(allocator, target, cfg);
|
||||
} else {
|
||||
return try fetchServerLogs(allocator, target, cfg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fetchLocalLogs(allocator: std.mem.Allocator, target: []const u8, cfg: *const config.Config, json: bool) !void {
|
||||
// Resolve manifest path
|
||||
const manifest_path = manifest_lib.resolveManifestPath(target, cfg.artifact_path, allocator) catch |err| {
|
||||
if (err == error.ManifestNotFound) {
|
||||
std.log.err("Run not found: {s}", .{target});
|
||||
return error.RunNotFound;
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
// Read manifest to get artifact path
|
||||
var manifest = try manifest_lib.readManifest(manifest_path, allocator);
|
||||
defer manifest.deinit(allocator);
|
||||
|
||||
// Build output.log path
|
||||
const output_path = try std.fs.path.join(allocator, &[_][]const u8{
|
||||
manifest.artifact_path,
|
||||
"output.log",
|
||||
});
|
||||
defer allocator.free(output_path);
|
||||
|
||||
// Read output.log
|
||||
const content = std.fs.cwd().readFileAlloc(allocator, output_path, 10 * 1024 * 1024) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
std.log.err("No logs found for run: {s}", .{target});
|
||||
return error.LogsNotFound;
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(content);
|
||||
|
||||
if (json) {
|
||||
// Escape content for JSON
|
||||
var escaped: std.ArrayList(u8) = .empty;
|
||||
defer escaped.deinit(allocator);
|
||||
const writer = escaped.writer(allocator);
|
||||
|
||||
for (content) |c| {
|
||||
switch (c) {
|
||||
'\\' => try writer.writeAll("\\\\"),
|
||||
'"' => try writer.writeAll("\\\""),
|
||||
'\n' => try writer.writeAll("\\n"),
|
||||
'\r' => try writer.writeAll("\\r"),
|
||||
'\t' => try writer.writeAll("\\t"),
|
||||
else => {
|
||||
if (c >= 0x20 and c < 0x7f) {
|
||||
try writer.writeByte(c);
|
||||
} else {
|
||||
try writer.print("\\u{x:0>4}", .{c});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"logs\":\"{s}\"}}\n", .{
|
||||
manifest.run_id,
|
||||
escaped.items,
|
||||
});
|
||||
} else {
|
||||
std.debug.print("{s}\n", .{content});
|
||||
}
|
||||
}
|
||||
|
||||
fn fetchServerLogs(allocator: std.mem.Allocator, target: []const u8, cfg: config.Config) !void {
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendGetLogs(target, api_key_hash);
|
||||
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
std.debug.print("{s}\n", .{message});
|
||||
}
|
||||
|
||||
fn streamServerLogs(allocator: std.mem.Allocator, target: []const u8, cfg: config.Config) !void {
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
colors.printInfo("Streaming logs for: {s}\n", .{target});
|
||||
|
||||
try client.sendStreamLogs(target, api_key_hash);
|
||||
|
||||
// Stream loop
|
||||
while (true) {
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
|
||||
std.debug.print("{s}\n", .{message});
|
||||
continue;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
switch (packet.packet_type) {
|
||||
.data => {
|
||||
if (packet.data_payload) |payload| {
|
||||
std.debug.print("{s}\n", .{payload});
|
||||
}
|
||||
},
|
||||
.error_packet => {
|
||||
const err_msg = packet.error_message orelse "Stream error";
|
||||
colors.printError("Error: {s}\n", .{err_msg});
|
||||
return error.ServerError;
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
std.debug.print("Usage: ml logs <run_id> [options]\n\n", .{});
|
||||
std.debug.print("Fetch or stream run logs.\n\n", .{});
|
||||
std.debug.print("Options:\n", .{});
|
||||
std.debug.print(" --follow, -f Stream logs from server (online mode)\n", .{});
|
||||
std.debug.print(" --help, -h Show this help message\n", .{});
|
||||
std.debug.print(" --json Output structured JSON\n\n", .{});
|
||||
std.debug.print("Examples:\n", .{});
|
||||
std.debug.print(" ml logs abc123 # Fetch logs (local or server)\n", .{});
|
||||
std.debug.print(" ml logs abc123 --follow # Stream logs from server\n", .{});
|
||||
}
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
|
||||
/// Logs command - fetch and display job logs via WebSocket API
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
const target = argv[0];
|
||||
|
||||
// Parse optional flags
|
||||
var follow = false;
|
||||
var tail: ?usize = null;
|
||||
|
||||
var i: usize = 1;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const a = argv[i];
|
||||
if (std.mem.eql(u8, a, "-f") or std.mem.eql(u8, a, "--follow")) {
|
||||
follow = true;
|
||||
} else if (std.mem.eql(u8, a, "-n") and i + 1 < argv.len) {
|
||||
tail = try std.fmt.parseInt(usize, argv[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--tail") and i + 1 < argv.len) {
|
||||
tail = try std.fmt.parseInt(usize, argv[i + 1], 10);
|
||||
i += 1;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
colors.printInfo("Fetching logs for: {s}\n", .{target});
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Send appropriate request based on follow flag
|
||||
if (follow) {
|
||||
try client.sendStreamLogs(target, api_key_hash);
|
||||
} else {
|
||||
try client.sendGetLogs(target, api_key_hash);
|
||||
}
|
||||
|
||||
// Receive and display response
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
|
||||
// Fallback: treat as plain text response
|
||||
std.debug.print("{s}\n", .{message});
|
||||
return;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
switch (packet.packet_type) {
|
||||
.data => {
|
||||
if (packet.data_payload) |payload| {
|
||||
// Parse JSON response
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch {
|
||||
std.debug.print("{s}\n", .{payload});
|
||||
return;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value.object;
|
||||
|
||||
// Display logs
|
||||
if (root.get("logs")) |logs| {
|
||||
if (logs == .string) {
|
||||
std.debug.print("{s}\n", .{logs.string});
|
||||
}
|
||||
} else if (root.get("message")) |msg| {
|
||||
if (msg == .string) {
|
||||
colors.printInfo("{s}\n", .{msg.string});
|
||||
}
|
||||
}
|
||||
|
||||
// Show truncation warning if applicable
|
||||
if (root.get("truncated")) |truncated| {
|
||||
if (truncated == .bool and truncated.bool) {
|
||||
if (root.get("total_lines")) |total| {
|
||||
if (total == .integer) {
|
||||
colors.printWarning("\n[Output truncated. Total lines: {d}]\n", .{total.integer});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
.error_packet => {
|
||||
const err_msg = packet.error_message orelse "Unknown error";
|
||||
colors.printError("Error: {s}\n", .{err_msg});
|
||||
return error.ServerError;
|
||||
},
|
||||
else => {
|
||||
if (packet.success_message) |msg| {
|
||||
colors.printSuccess("{s}\n", .{msg});
|
||||
} else {
|
||||
colors.printInfo("Logs retrieved successfully\n", .{});
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage:\n", .{});
|
||||
colors.printInfo(" ml logs <task_id|run_id|experiment_id> [-f|--follow] [-n <count>|--tail <count>]\n", .{});
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml logs abc123 # Show full logs\n", .{});
|
||||
colors.printInfo(" ml logs abc123 -f # Follow logs in real-time\n", .{});
|
||||
colors.printInfo(" ml logs abc123 -n 100 # Show last 100 lines\n", .{});
|
||||
}
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
return;
|
||||
}
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
std.debug.print("monitor does not support --json\n", .{});
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
std.debug.print("Launching TUI via SSH...\n", .{});
|
||||
|
||||
// Build remote command that exports config via env vars and runs the TUI
|
||||
var remote_cmd_buffer = std.ArrayList(u8){};
|
||||
defer remote_cmd_buffer.deinit(allocator);
|
||||
{
|
||||
const writer = remote_cmd_buffer.writer(allocator);
|
||||
try writer.print("cd {s} && ", .{config.worker_base});
|
||||
try writer.print(
|
||||
"FETCH_ML_CLI_HOST=\"{s}\" FETCH_ML_CLI_USER=\"{s}\" FETCH_ML_CLI_BASE=\"{s}\" ",
|
||||
.{ config.worker_host, config.worker_user, config.worker_base },
|
||||
);
|
||||
try writer.print(
|
||||
"FETCH_ML_CLI_PORT=\"{d}\" FETCH_ML_CLI_API_KEY=\"{s}\" ",
|
||||
.{ config.worker_port, config.api_key },
|
||||
);
|
||||
try writer.writeAll("./bin/tui");
|
||||
for (args) |arg| {
|
||||
try writer.print(" {s}", .{arg});
|
||||
}
|
||||
}
|
||||
const remote_cmd = try remote_cmd_buffer.toOwnedSlice();
|
||||
defer allocator.free(remote_cmd);
|
||||
|
||||
const ssh_cmd = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"ssh -t -p {d} {s}@{s} '{s}'",
|
||||
.{ config.worker_port, config.worker_user, config.worker_host, remote_cmd },
|
||||
);
|
||||
defer allocator.free(ssh_cmd);
|
||||
|
||||
var child = std.process.Child.init(&[_][]const u8{ "sh", "-c", ssh_cmd }, allocator);
|
||||
child.stdin_behavior = .Inherit;
|
||||
child.stdout_behavior = .Inherit;
|
||||
child.stderr_behavior = .Inherit;
|
||||
|
||||
const term = try child.spawnAndWait();
|
||||
|
||||
if (term.tag == .Exited and term.Exited != 0) {
|
||||
std.debug.print("TUI exited with code {d}\n", .{term.Exited});
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() void {
|
||||
std.debug.print("Usage: ml monitor [-- <tui-args...>]\n\n", .{});
|
||||
std.debug.print("Launches the remote TUI over SSH using ~/.ml/config.toml\n", .{});
|
||||
}
|
||||
|
|
@ -1,251 +0,0 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
|
||||
if (argv.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const sub = argv[0];
|
||||
if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!std.mem.eql(u8, sub, "set")) {
|
||||
colors.printError("Unknown subcommand: {s}\n", .{sub});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (argv.len < 2) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const target = argv[1];
|
||||
|
||||
var hypothesis: ?[]const u8 = null;
|
||||
var context: ?[]const u8 = null;
|
||||
var intent: ?[]const u8 = null;
|
||||
var expected_outcome: ?[]const u8 = null;
|
||||
var parent_run: ?[]const u8 = null;
|
||||
var experiment_group: ?[]const u8 = null;
|
||||
var tags_csv: ?[]const u8 = null;
|
||||
var base_override: ?[]const u8 = null;
|
||||
var json_mode: bool = false;
|
||||
|
||||
var i: usize = 2;
|
||||
while (i < argv.len) : (i += 1) {
|
||||
const a = argv[i];
|
||||
if (std.mem.eql(u8, a, "--hypothesis")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
hypothesis = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--context")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
context = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--intent")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
intent = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--expected-outcome")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
expected_outcome = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--parent-run")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
parent_run = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--experiment-group")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
experiment_group = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--tags")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
tags_csv = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--base")) {
|
||||
if (i + 1 >= argv.len) return error.InvalidArgs;
|
||||
base_override = argv[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, a, "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
|
||||
try printUsage();
|
||||
return;
|
||||
} else {
|
||||
colors.printError("Unknown option: {s}\n", .{a});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
}
|
||||
|
||||
if (hypothesis == null and context == null and intent == null and expected_outcome == null and parent_run == null and experiment_group == null and tags_csv == null) {
|
||||
colors.printError("No narrative fields provided.\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
const cfg = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const resolved_base = base_override orelse cfg.worker_base;
|
||||
const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
colors.printError(
|
||||
"Could not locate run_manifest.json for '{s}'. Provide a path, or use --base <path> to scan finished/failed/running/pending.\n",
|
||||
.{target},
|
||||
);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer allocator.free(manifest_path);
|
||||
|
||||
const job_name = try manifest.readJobNameFromManifest(allocator, manifest_path);
|
||||
defer allocator.free(job_name);
|
||||
|
||||
const patch_json = try buildPatchJSON(
|
||||
allocator,
|
||||
hypothesis,
|
||||
context,
|
||||
intent,
|
||||
expected_outcome,
|
||||
parent_run,
|
||||
experiment_group,
|
||||
tags_csv,
|
||||
);
|
||||
defer allocator.free(patch_json);
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendSetRunNarrative(job_name, patch_json, api_key_hash);
|
||||
|
||||
if (json_mode) {
|
||||
const msg = try client.receiveMessage(allocator);
|
||||
defer allocator.free(msg);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
return error.InvalidPacket;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
const Result = struct {
|
||||
ok: bool,
|
||||
job_name: []const u8,
|
||||
message: []const u8,
|
||||
error_code: ?u8 = null,
|
||||
error_message: ?[]const u8 = null,
|
||||
details: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
var out = io.stdoutWriter();
|
||||
if (packet.packet_type == .error_packet) {
|
||||
const res = Result{
|
||||
.ok = false,
|
||||
.job_name = job_name,
|
||||
.message = "",
|
||||
.error_code = @intFromEnum(packet.error_code.?),
|
||||
.error_message = packet.error_message orelse "",
|
||||
.details = packet.error_details orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return error.CommandFailed;
|
||||
}
|
||||
|
||||
const res = Result{
|
||||
.ok = true,
|
||||
.job_name = job_name,
|
||||
.message = packet.success_message orelse "",
|
||||
};
|
||||
try out.print("{f}\n", .{std.json.fmt(res, .{})});
|
||||
return;
|
||||
}
|
||||
|
||||
try client.receiveAndHandleResponse(allocator, "Narrative");
|
||||
|
||||
colors.printSuccess("Narrative updated\n", .{});
|
||||
colors.printInfo("Job: {s}\n", .{job_name});
|
||||
}
|
||||
|
||||
fn buildPatchJSON(
|
||||
allocator: std.mem.Allocator,
|
||||
hypothesis: ?[]const u8,
|
||||
context: ?[]const u8,
|
||||
intent: ?[]const u8,
|
||||
expected_outcome: ?[]const u8,
|
||||
parent_run: ?[]const u8,
|
||||
experiment_group: ?[]const u8,
|
||||
tags_csv: ?[]const u8,
|
||||
) ![]u8 {
|
||||
var out = std.ArrayList(u8).initCapacity(allocator, 256) catch return error.OutOfMemory;
|
||||
defer out.deinit(allocator);
|
||||
|
||||
var tags_list = std.ArrayList([]const u8).initCapacity(allocator, 8) catch return error.OutOfMemory;
|
||||
defer tags_list.deinit(allocator);
|
||||
|
||||
if (tags_csv) |csv| {
|
||||
var it = std.mem.splitScalar(u8, csv, ',');
|
||||
while (it.next()) |part| {
|
||||
const trimmed = std.mem.trim(u8, part, " \t\r\n");
|
||||
if (trimmed.len == 0) continue;
|
||||
try tags_list.append(allocator, trimmed);
|
||||
}
|
||||
}
|
||||
|
||||
const Patch = struct {
|
||||
hypothesis: ?[]const u8 = null,
|
||||
context: ?[]const u8 = null,
|
||||
intent: ?[]const u8 = null,
|
||||
expected_outcome: ?[]const u8 = null,
|
||||
parent_run: ?[]const u8 = null,
|
||||
experiment_group: ?[]const u8 = null,
|
||||
tags: ?[]const []const u8 = null,
|
||||
};
|
||||
|
||||
const patch = Patch{
|
||||
.hypothesis = hypothesis,
|
||||
.context = context,
|
||||
.intent = intent,
|
||||
.expected_outcome = expected_outcome,
|
||||
.parent_run = parent_run,
|
||||
.experiment_group = experiment_group,
|
||||
.tags = if (tags_list.items.len > 0) tags_list.items else null,
|
||||
};
|
||||
|
||||
const writer = out.writer(allocator);
|
||||
try writer.print("{f}", .{std.json.fmt(patch, .{})});
|
||||
return out.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml narrative set <path|run_id|task_id> [fields]\n", .{});
|
||||
colors.printInfo("\nFields:\n", .{});
|
||||
colors.printInfo(" --hypothesis \"...\"\n", .{});
|
||||
colors.printInfo(" --context \"...\"\n", .{});
|
||||
colors.printInfo(" --intent \"...\"\n", .{});
|
||||
colors.printInfo(" --expected-outcome \"...\"\n", .{});
|
||||
colors.printInfo(" --parent-run <id>\n", .{});
|
||||
colors.printInfo(" --experiment-group <name>\n", .{});
|
||||
colors.printInfo(" --tags a,b,c\n", .{});
|
||||
colors.printInfo(" --base <path>\n", .{});
|
||||
colors.printInfo(" --json\n", .{});
|
||||
}
|
||||
|
|
@ -3,20 +3,20 @@ const Config = @import("../config.zig").Config;
|
|||
const ws = @import("../net/ws/client.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var keep_count: ?u32 = null;
|
||||
var older_than_days: ?u32 = null;
|
||||
var json: bool = false;
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
if (std.mem.eql(u8, args[i], "--help") or std.mem.eql(u8, args[i], "-h")) {
|
||||
printUsage();
|
||||
return;
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, args[i], "--json")) {
|
||||
json = true;
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--keep") and i + 1 < args.len) {
|
||||
keep_count = try std.fmt.parseInt(u32, args[i + 1], 10);
|
||||
i += 1;
|
||||
|
|
@ -26,8 +26,10 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.flags.json) .flags.json else .text);
|
||||
|
||||
if (keep_count == null and older_than_days == null) {
|
||||
printUsage();
|
||||
core.output.usage("prune", "ml prune --keep <n> | --older-than <days>");
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
|
|
@ -38,7 +40,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
|
||||
// Add confirmation prompt
|
||||
if (!json) {
|
||||
if (!flags.flags.json) {
|
||||
if (keep_count) |count| {
|
||||
if (!logging.confirm("This will permanently delete all but the {d} most recent experiments. Continue?", .{count})) {
|
||||
logging.info("Prune cancelled.\n", .{});
|
||||
|
|
@ -90,13 +92,13 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
// Parse prune response (simplified - assumes success/failure byte)
|
||||
if (response.len > 0) {
|
||||
if (response[0] == 0x00) {
|
||||
if (json) {
|
||||
if (flags.json) {
|
||||
std.debug.print("{\"ok\":true}\n", .{});
|
||||
} else {
|
||||
logging.success("✓ Prune operation completed successfully\n", .{});
|
||||
}
|
||||
} else {
|
||||
if (json) {
|
||||
if (flags.json) {
|
||||
std.debug.print("{\"ok\":false,\"error_code\":{d}}\n", .{response[0]});
|
||||
} else {
|
||||
logging.err("✗ Prune operation failed: error code {d}\n", .{response[0]});
|
||||
|
|
@ -104,7 +106,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
return error.PruneFailed;
|
||||
}
|
||||
} else {
|
||||
if (json) {
|
||||
if (flags.json) {
|
||||
std.debug.print("{\"ok\":true,\"note\":\"no_response\"}\n", .{});
|
||||
} else {
|
||||
logging.success("✓ Prune request sent (no response received)\n", .{});
|
||||
|
|
@ -117,6 +119,6 @@ fn printUsage() void {
|
|||
logging.info("Options:\n", .{});
|
||||
logging.info(" --keep <N> Keep N most recent experiments\n", .{});
|
||||
logging.info(" --older-than <days> Remove experiments older than N days\n", .{});
|
||||
logging.info(" --json Output machine-readable JSON\n", .{});
|
||||
logging.info(" --flags.json Output machine-readable JSON\n", .{});
|
||||
logging.info(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ const history = @import("../utils/history.zig");
|
|||
const crypto = @import("../utils/crypto.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const stdcrypto = std.crypto;
|
||||
const mode = @import("../mode.zig");
|
||||
const db = @import("../db.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
|
||||
pub const TrackingConfig = struct {
|
||||
mlflow: ?MLflowConfig = null,
|
||||
|
|
@ -42,6 +45,17 @@ pub const QueueOptions = struct {
|
|||
memory: u8 = 8,
|
||||
gpu: u8 = 0,
|
||||
gpu_memory: ?[]const u8 = null,
|
||||
// Narrative fields for research context
|
||||
hypothesis: ?[]const u8 = null,
|
||||
context: ?[]const u8 = null,
|
||||
intent: ?[]const u8 = null,
|
||||
expected_outcome: ?[]const u8 = null,
|
||||
experiment_group: ?[]const u8 = null,
|
||||
tags: ?[]const u8 = null,
|
||||
// Sandboxing options
|
||||
network_mode: ?[]const u8 = null,
|
||||
read_only: bool = false,
|
||||
secrets: std.ArrayList([]const u8),
|
||||
};
|
||||
|
||||
fn resolveCommitHexOrPrefix(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
|
||||
|
|
@ -92,6 +106,45 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
return;
|
||||
}
|
||||
|
||||
// Load config for mode detection
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Detect mode early to provide clear error for offline
|
||||
const mode_result = try mode.detect(allocator, config);
|
||||
|
||||
// Check for --rerun flag
|
||||
var rerun_id: ?[]const u8 = null;
|
||||
for (args, 0..) |arg, i| {
|
||||
if (std.mem.eql(u8, arg, "--rerun") and i + 1 < args.len) {
|
||||
rerun_id = args[i + 1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If --rerun is specified, handle re-queueing
|
||||
if (rerun_id) |id| {
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml queue --rerun requires server connection\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
return try handleRerun(allocator, id, args, config);
|
||||
}
|
||||
|
||||
// Regular queue - requires server
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml queue requires server connection (use 'ml run' for local execution)\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
|
||||
// Continue with regular queue logic...
|
||||
try executeQueue(allocator, args, config);
|
||||
}
|
||||
|
||||
fn executeQueue(allocator: std.mem.Allocator, args: []const []const u8, config: Config) !void {
|
||||
// Support batch operations - multiple job names
|
||||
var job_names = std.ArrayList([]const u8).initCapacity(allocator, 10) catch |err| {
|
||||
colors.printError("Failed to allocate job list: {}\n", .{err});
|
||||
|
|
@ -106,13 +159,6 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
var args_override: ?[]const u8 = null;
|
||||
var note_override: ?[]const u8 = null;
|
||||
|
||||
// Load configuration to get defaults
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Initialize options with config defaults
|
||||
var options = QueueOptions{
|
||||
.cpu = config.default_cpu,
|
||||
|
|
@ -122,7 +168,9 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
.dry_run = config.default_dry_run,
|
||||
.validate = config.default_validate,
|
||||
.json = config.default_json,
|
||||
.secrets = std.ArrayList([]const u8).empty,
|
||||
};
|
||||
defer options.secrets.deinit(allocator);
|
||||
priority = config.default_priority;
|
||||
|
||||
// Tracking configuration
|
||||
|
|
@ -254,6 +302,32 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
} else if (std.mem.eql(u8, arg, "--note") and i + 1 < pre.len) {
|
||||
note_override = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--hypothesis") and i + 1 < pre.len) {
|
||||
options.hypothesis = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--context") and i + 1 < pre.len) {
|
||||
options.context = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--intent") and i + 1 < pre.len) {
|
||||
options.intent = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--expected-outcome") and i + 1 < pre.len) {
|
||||
options.expected_outcome = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--experiment-group") and i + 1 < pre.len) {
|
||||
options.experiment_group = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--tags") and i + 1 < pre.len) {
|
||||
options.tags = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--network") and i + 1 < pre.len) {
|
||||
options.network_mode = pre[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, arg, "--read-only")) {
|
||||
options.read_only = true;
|
||||
} else if (std.mem.eql(u8, arg, "--secret") and i + 1 < pre.len) {
|
||||
try options.secrets.append(allocator, pre[i + 1]);
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
// This is a job name
|
||||
|
|
@ -352,6 +426,35 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
}
|
||||
|
||||
/// Handle --rerun flag: re-queue a completed run
|
||||
fn handleRerun(allocator: std.mem.Allocator, run_id: []const u8, args: []const []const u8, cfg: Config) !void {
|
||||
_ = args; // Override args not implemented yet
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Send rerun request to server
|
||||
try client.sendRerunRequest(run_id, api_key_hash);
|
||||
|
||||
// Wait for response
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
// Parse response (simplified)
|
||||
if (std.mem.indexOf(u8, message, "success") != null) {
|
||||
colors.printSuccess("✓ Re-queued run {s}\n", .{run_id[0..8]});
|
||||
} else {
|
||||
colors.printError("Failed to re-queue: {s}\n", .{message});
|
||||
return error.RerunFailed;
|
||||
}
|
||||
}
|
||||
|
||||
fn generateCommitID(allocator: std.mem.Allocator) ![]const u8 {
|
||||
var bytes: [20]u8 = undefined;
|
||||
stdcrypto.random.bytes(&bytes);
|
||||
|
|
@ -378,6 +481,10 @@ fn queueSingleJob(
|
|||
};
|
||||
defer if (commit_override == null) allocator.free(commit_id);
|
||||
|
||||
// Build narrative JSON if any narrative fields are set
|
||||
const narrative_json = buildNarrativeJson(allocator, options) catch null;
|
||||
defer if (narrative_json) |j| allocator.free(j);
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
|
|
@ -386,11 +493,23 @@ fn queueSingleJob(
|
|||
|
||||
const commit_hex = try crypto.encodeHexLower(allocator, commit_id);
|
||||
defer allocator.free(commit_hex);
|
||||
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
// Check for existing job with same commit (incremental queue)
|
||||
if (!options.force) {
|
||||
const existing = try checkExistingJob(allocator, job_name, commit_id, api_key_hash, config);
|
||||
if (existing) |ex| {
|
||||
defer allocator.free(ex);
|
||||
// Server already has this job - handle duplicate response
|
||||
try handleDuplicateResponse(allocator, ex, job_name, commit_hex, options);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("Queueing job '{s}' with commit {s}...\n", .{ job_name, commit_hex });
|
||||
|
||||
// Connect to WebSocket and send queue message
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
|
@ -407,13 +526,43 @@ fn queueSingleJob(
|
|||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (tracking_json.len > 0) {
|
||||
// Build combined metadata JSON with tracking and/or narrative
|
||||
const combined_json = blk: {
|
||||
if (tracking_json.len > 0 and narrative_json != null) {
|
||||
// Merge tracking and narrative
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{");
|
||||
try writer.writeAll(tracking_json[1 .. tracking_json.len - 1]); // Remove outer braces
|
||||
try writer.writeAll(",");
|
||||
try writer.writeAll("\"narrative\":");
|
||||
try writer.writeAll(narrative_json.?);
|
||||
try writer.writeAll("}");
|
||||
break :blk try buf.toOwnedSlice(allocator);
|
||||
} else if (tracking_json.len > 0) {
|
||||
break :blk try allocator.dupe(u8, tracking_json);
|
||||
} else if (narrative_json) |nj| {
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{\"narrative\":");
|
||||
try writer.writeAll(nj);
|
||||
try writer.writeAll("}");
|
||||
break :blk try buf.toOwnedSlice(allocator);
|
||||
} else {
|
||||
break :blk "";
|
||||
}
|
||||
};
|
||||
defer if (combined_json.len > 0 and combined_json.ptr != tracking_json.ptr) allocator.free(combined_json);
|
||||
|
||||
if (combined_json.len > 0) {
|
||||
try client.sendQueueJobWithTrackingAndResources(
|
||||
job_name,
|
||||
commit_id,
|
||||
priority,
|
||||
api_key_hash,
|
||||
tracking_json,
|
||||
combined_json,
|
||||
options.cpu,
|
||||
options.memory,
|
||||
options.gpu,
|
||||
|
|
@ -536,6 +685,7 @@ fn queueSingleJob(
|
|||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml queue <job-name> [job-name ...] [options]\n", .{});
|
||||
colors.printInfo(" ml queue --rerun <run_id> # Re-queue a completed run\n", .{});
|
||||
colors.printInfo("\nBasic Options:\n", .{});
|
||||
colors.printInfo(" --commit <id> Specify commit ID\n", .{});
|
||||
colors.printInfo(" --priority <num> Set priority (0-255, default: 5)\n", .{});
|
||||
|
|
@ -547,7 +697,15 @@ fn printUsage() !void {
|
|||
colors.printInfo(" --args <string> Extra runner args (sent to worker as task.Args)\n", .{});
|
||||
colors.printInfo(" --note <string> Human notes (stored in run manifest as metadata.note)\n", .{});
|
||||
colors.printInfo(" -- <args...> Extra runner args (alternative to --args)\n", .{});
|
||||
colors.printInfo("\nResearch Narrative:\n", .{});
|
||||
colors.printInfo(" --hypothesis <text> Research hypothesis being tested\n", .{});
|
||||
colors.printInfo(" --context <text> Background context for this experiment\n", .{});
|
||||
colors.printInfo(" --intent <text> What you're trying to accomplish\n", .{});
|
||||
colors.printInfo(" --expected-outcome <text> What you expect to happen\n", .{});
|
||||
colors.printInfo(" --experiment-group <name> Group related experiments\n", .{});
|
||||
colors.printInfo(" --tags <csv> Comma-separated tags (e.g., ablation,batch-size)\n", .{});
|
||||
colors.printInfo("\nSpecial Modes:\n", .{});
|
||||
colors.printInfo(" --rerun <run_id> Re-queue a completed local run to server\n", .{});
|
||||
colors.printInfo(" --dry-run Show what would be queued\n", .{});
|
||||
colors.printInfo(" --validate Validate experiment without queuing\n", .{});
|
||||
colors.printInfo(" --explain Explain what will happen\n", .{});
|
||||
|
|
@ -561,11 +719,20 @@ fn printUsage() !void {
|
|||
colors.printInfo(" --wandb-project <prj> Set Wandb project\n", .{});
|
||||
colors.printInfo(" --wandb-entity <ent> Set Wandb entity\n", .{});
|
||||
|
||||
colors.printInfo("\nSandboxing:\n", .{});
|
||||
colors.printInfo(" --network <mode> Network mode: none, bridge, slirp4netns\n", .{});
|
||||
colors.printInfo(" --read-only Mount root filesystem as read-only\n", .{});
|
||||
colors.printInfo(" --secret <name> Inject secret as env var (can repeat)\n", .{});
|
||||
|
||||
colors.printInfo("\nExamples:\n", .{});
|
||||
colors.printInfo(" ml queue my_job # Queue a job\n", .{});
|
||||
colors.printInfo(" ml queue my_job --dry-run # Preview submission\n", .{});
|
||||
colors.printInfo(" ml queue my_job --validate # Validate locally\n", .{});
|
||||
colors.printInfo(" ml queue --rerun abc123 # Re-queue completed run\n", .{});
|
||||
colors.printInfo(" ml status --watch # Watch queue + prewarm\n", .{});
|
||||
colors.printInfo("\nResearch Examples:\n", .{});
|
||||
colors.printInfo(" ml queue train.py --hypothesis 'LR scaling improves convergence' \n", .{});
|
||||
colors.printInfo(" --context 'Following paper XYZ' --tags ablation,lr-scaling\n", .{});
|
||||
}
|
||||
|
||||
pub fn formatNextSteps(allocator: std.mem.Allocator, job_name: []const u8, commit_hex: []const u8) ![]u8 {
|
||||
|
|
@ -597,13 +764,23 @@ fn explainJob(
|
|||
commit_display = enc;
|
||||
}
|
||||
|
||||
// Build narrative JSON for display
|
||||
const narrative_json = buildNarrativeJson(allocator, options) catch null;
|
||||
defer if (narrative_json) |j| allocator.free(j);
|
||||
|
||||
if (options.json) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
var buffer: [4096]u8 = undefined;
|
||||
const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"explain\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable;
|
||||
try stdout_file.writeAll(formatted);
|
||||
try writeJSONNullableString(&stdout_file, options.gpu_memory);
|
||||
try stdout_file.writeAll("}}\n");
|
||||
if (narrative_json) |nj| {
|
||||
try stdout_file.writeAll("},\"narrative\":");
|
||||
try stdout_file.writeAll(nj);
|
||||
try stdout_file.writeAll("}\n");
|
||||
} else {
|
||||
try stdout_file.writeAll("}}\n");
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
colors.printInfo("Job Explanation:\n", .{});
|
||||
|
|
@ -616,7 +793,30 @@ fn explainJob(
|
|||
colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu});
|
||||
colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
|
||||
|
||||
colors.printInfo(" Action: Job would be queued for execution\n", .{});
|
||||
// Display narrative if provided
|
||||
if (narrative_json != null) {
|
||||
colors.printInfo("\n Research Narrative:\n", .{});
|
||||
if (options.hypothesis) |h| {
|
||||
colors.printInfo(" Hypothesis: {s}\n", .{h});
|
||||
}
|
||||
if (options.context) |c| {
|
||||
colors.printInfo(" Context: {s}\n", .{c});
|
||||
}
|
||||
if (options.intent) |i| {
|
||||
colors.printInfo(" Intent: {s}\n", .{i});
|
||||
}
|
||||
if (options.expected_outcome) |eo| {
|
||||
colors.printInfo(" Expected Outcome: {s}\n", .{eo});
|
||||
}
|
||||
if (options.experiment_group) |eg| {
|
||||
colors.printInfo(" Experiment Group: {s}\n", .{eg});
|
||||
}
|
||||
if (options.tags) |t| {
|
||||
colors.printInfo(" Tags: {s}\n", .{t});
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\n Action: Job would be queued for execution\n", .{});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -689,13 +889,23 @@ fn dryRunJob(
|
|||
commit_display = enc;
|
||||
}
|
||||
|
||||
// Build narrative JSON for display
|
||||
const narrative_json = buildNarrativeJson(allocator, options) catch null;
|
||||
defer if (narrative_json) |j| allocator.free(j);
|
||||
|
||||
if (options.json) {
|
||||
const stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
|
||||
var buffer: [4096]u8 = undefined;
|
||||
const formatted = std.fmt.bufPrint(&buffer, "{{\"action\":\"dry_run\",\"job_name\":\"{s}\",\"commit_id\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory_gb\":{d},\"gpu\":{d},\"gpu_memory\":", .{ job_name, commit_display, priority, options.cpu, options.memory, options.gpu }) catch unreachable;
|
||||
try stdout_file.writeAll(formatted);
|
||||
try writeJSONNullableString(&stdout_file, options.gpu_memory);
|
||||
try stdout_file.writeAll("}},\"would_queue\":true}}\n");
|
||||
if (narrative_json) |nj| {
|
||||
try stdout_file.writeAll("},\"narrative\":");
|
||||
try stdout_file.writeAll(nj);
|
||||
try stdout_file.writeAll(",\"would_queue\":true}}\n");
|
||||
} else {
|
||||
try stdout_file.writeAll("},\"would_queue\":true}}\n");
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
colors.printInfo("Dry Run - Job Queue Preview:\n", .{});
|
||||
|
|
@ -708,7 +918,30 @@ fn dryRunJob(
|
|||
colors.printInfo(" GPU: {d} device(s)\n", .{options.gpu});
|
||||
colors.printInfo(" GPU Memory: {s}\n", .{options.gpu_memory orelse "auto"});
|
||||
|
||||
colors.printInfo(" Action: Would queue job\n", .{});
|
||||
// Display narrative if provided
|
||||
if (narrative_json != null) {
|
||||
colors.printInfo("\n Research Narrative:\n", .{});
|
||||
if (options.hypothesis) |h| {
|
||||
colors.printInfo(" Hypothesis: {s}\n", .{h});
|
||||
}
|
||||
if (options.context) |c| {
|
||||
colors.printInfo(" Context: {s}\n", .{c});
|
||||
}
|
||||
if (options.intent) |i| {
|
||||
colors.printInfo(" Intent: {s}\n", .{i});
|
||||
}
|
||||
if (options.expected_outcome) |eo| {
|
||||
colors.printInfo(" Expected Outcome: {s}\n", .{eo});
|
||||
}
|
||||
if (options.experiment_group) |eg| {
|
||||
colors.printInfo(" Experiment Group: {s}\n", .{eg});
|
||||
}
|
||||
if (options.tags) |t| {
|
||||
colors.printInfo(" Tags: {s}\n", .{t});
|
||||
}
|
||||
}
|
||||
|
||||
colors.printInfo("\n Action: Would queue job\n", .{});
|
||||
colors.printInfo(" Estimated queue time: 2-5 minutes\n", .{});
|
||||
colors.printSuccess(" ✓ Dry run completed - no job was actually queued\n", .{});
|
||||
}
|
||||
|
|
@ -923,3 +1156,114 @@ fn handleDuplicateResponse(
|
|||
fn hexDigit(v: u8) u8 {
|
||||
return if (v < 10) ('0' + v) else ('a' + (v - 10));
|
||||
}
|
||||
|
||||
// buildNarrativeJson creates a JSON object from narrative fields
|
||||
fn buildNarrativeJson(allocator: std.mem.Allocator, options: *const QueueOptions) !?[]u8 {
|
||||
// Check if any narrative field is set
|
||||
if (options.hypothesis == null and
|
||||
options.context == null and
|
||||
options.intent == null and
|
||||
options.expected_outcome == null and
|
||||
options.experiment_group == null and
|
||||
options.tags == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
const writer = buf.writer(allocator);
|
||||
try writer.writeAll("{");
|
||||
|
||||
var first = true;
|
||||
|
||||
if (options.hypothesis) |h| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"hypothesis\":");
|
||||
try writeJSONString(writer, h);
|
||||
}
|
||||
|
||||
if (options.context) |c| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"context\":");
|
||||
try writeJSONString(writer, c);
|
||||
}
|
||||
|
||||
if (options.intent) |i| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"intent\":");
|
||||
try writeJSONString(writer, i);
|
||||
}
|
||||
|
||||
if (options.expected_outcome) |eo| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"expected_outcome\":");
|
||||
try writeJSONString(writer, eo);
|
||||
}
|
||||
|
||||
if (options.experiment_group) |eg| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"experiment_group\":");
|
||||
try writeJSONString(writer, eg);
|
||||
}
|
||||
|
||||
if (options.tags) |t| {
|
||||
if (!first) try writer.writeAll(",");
|
||||
first = false;
|
||||
try writer.writeAll("\"tags\":");
|
||||
try writeJSONString(writer, t);
|
||||
}
|
||||
|
||||
try writer.writeAll("}");
|
||||
|
||||
return try buf.toOwnedSlice(allocator);
|
||||
}
|
||||
|
||||
/// Check if a job with the same commit_id already exists on the server
|
||||
/// Returns: Optional JSON response from server if duplicate found
|
||||
fn checkExistingJob(
|
||||
allocator: std.mem.Allocator,
|
||||
job_name: []const u8,
|
||||
commit_id: []const u8,
|
||||
api_key_hash: []const u8,
|
||||
config: Config,
|
||||
) !?[]const u8 {
|
||||
// Connect to server and query for existing job
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Send query for existing job
|
||||
try client.sendQueryJobByCommit(job_name, commit_id, api_key_hash);
|
||||
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
// Parse response
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch |err| {
|
||||
// If JSON parse fails, treat as no duplicate found
|
||||
std.log.debug("Failed to parse check response: {}", .{err});
|
||||
return null;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value.object;
|
||||
|
||||
// Check if job exists
|
||||
if (root.get("exists")) |exists| {
|
||||
if (!exists.bool) return null;
|
||||
|
||||
// Job exists - copy the full response for caller
|
||||
return try allocator.dupe(u8, message);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
|
|
|||
3
cli/src/commands/queue/index.zig
Normal file
3
cli/src/commands/queue/index.zig
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub const parse = @import("queue/parse.zig");
|
||||
pub const validate = @import("queue/validate.zig");
|
||||
pub const submit = @import("queue/submit.zig");
|
||||
177
cli/src/commands/queue/parse.zig
Normal file
177
cli/src/commands/queue/parse.zig
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Parse job template from command line arguments
|
||||
pub const JobTemplate = struct {
|
||||
job_names: std.ArrayList([]const u8),
|
||||
commit_id_override: ?[]const u8,
|
||||
priority: u8,
|
||||
snapshot_id: ?[]const u8,
|
||||
snapshot_sha256: ?[]const u8,
|
||||
args_override: ?[]const u8,
|
||||
note_override: ?[]const u8,
|
||||
cpu: u8,
|
||||
memory: u8,
|
||||
gpu: u8,
|
||||
gpu_memory: ?[]const u8,
|
||||
dry_run: bool,
|
||||
validate: bool,
|
||||
explain: bool,
|
||||
json: bool,
|
||||
force: bool,
|
||||
runner_args_start: ?usize,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator) JobTemplate {
|
||||
return .{
|
||||
.job_names = std.ArrayList([]const u8).init(allocator),
|
||||
.commit_id_override = null,
|
||||
.priority = 5,
|
||||
.snapshot_id = null,
|
||||
.snapshot_sha256 = null,
|
||||
.args_override = null,
|
||||
.note_override = null,
|
||||
.cpu = 2,
|
||||
.memory = 8,
|
||||
.gpu = 0,
|
||||
.gpu_memory = null,
|
||||
.dry_run = false,
|
||||
.validate = false,
|
||||
.explain = false,
|
||||
.json = false,
|
||||
.force = false,
|
||||
.runner_args_start = null,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *JobTemplate, allocator: std.mem.Allocator) void {
|
||||
self.job_names.deinit(allocator);
|
||||
}
|
||||
};
|
||||
|
||||
/// Parse command arguments into a job template
|
||||
pub fn parseArgs(allocator: std.mem.Allocator, args: []const []const u8) !JobTemplate {
|
||||
var template = JobTemplate.init(allocator);
|
||||
errdefer template.deinit(allocator);
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
|
||||
if (std.mem.eql(u8, arg, "--")) {
|
||||
template.runner_args_start = i + 1;
|
||||
break;
|
||||
} else if (std.mem.eql(u8, arg, "--commit-id")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.commit_id_override = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--priority")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.priority = std.fmt.parseInt(u8, args[i + 1], 10) catch 5;
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--snapshot")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.snapshot_id = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--snapshot-sha256")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.snapshot_sha256 = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--args")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.args_override = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--note")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.note_override = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--cpu")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.cpu = std.fmt.parseInt(u8, args[i + 1], 10) catch 2;
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--memory")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.memory = std.fmt.parseInt(u8, args[i + 1], 10) catch 8;
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--gpu")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.gpu = std.fmt.parseInt(u8, args[i + 1], 10) catch 0;
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--gpu-memory")) {
|
||||
if (i + 1 < args.len) {
|
||||
template.gpu_memory = args[i + 1];
|
||||
i += 1;
|
||||
}
|
||||
} else if (std.mem.eql(u8, arg, "--dry-run")) {
|
||||
template.dry_run = true;
|
||||
} else if (std.mem.eql(u8, arg, "--validate")) {
|
||||
template.validate = true;
|
||||
} else if (std.mem.eql(u8, arg, "--explain")) {
|
||||
template.explain = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
template.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--force")) {
|
||||
template.force = true;
|
||||
} else if (!std.mem.startsWith(u8, arg, "-")) {
|
||||
// Positional argument - job name
|
||||
try template.job_names.append(arg);
|
||||
}
|
||||
}
|
||||
|
||||
return template;
|
||||
}
|
||||
|
||||
/// Get runner args from the parsed template
|
||||
pub fn getRunnerArgs(self: JobTemplate, all_args: []const []const u8) []const []const u8 {
|
||||
if (self.runner_args_start) |start| {
|
||||
if (start < all_args.len) {
|
||||
return all_args[start..];
|
||||
}
|
||||
}
|
||||
return &[_][]const u8{};
|
||||
}
|
||||
|
||||
/// Resolve commit ID from prefix or full hash
|
||||
pub fn resolveCommitId(allocator: std.mem.Allocator, base_path: []const u8, input: []const u8) ![]u8 {
|
||||
if (input.len < 7 or input.len > 40) return error.InvalidArgs;
|
||||
for (input) |c| {
|
||||
if (!std.ascii.isHex(c)) return error.InvalidArgs;
|
||||
}
|
||||
|
||||
if (input.len == 40) {
|
||||
return allocator.dupe(u8, input);
|
||||
}
|
||||
|
||||
var dir = if (std.fs.path.isAbsolute(base_path))
|
||||
try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
|
||||
else
|
||||
try std.fs.cwd().openDir(base_path, .{ .iterate = true });
|
||||
defer dir.close();
|
||||
|
||||
var it = dir.iterate();
|
||||
var found: ?[]u8 = null;
|
||||
errdefer if (found) |s| allocator.free(s);
|
||||
|
||||
while (try it.next()) |entry| {
|
||||
if (entry.kind != .directory) continue;
|
||||
const name = entry.name;
|
||||
if (name.len != 40) continue;
|
||||
if (!std.mem.startsWith(u8, name, input)) continue;
|
||||
for (name) |c| {
|
||||
if (!std.ascii.isHex(c)) break;
|
||||
} else {
|
||||
if (found != null) return error.InvalidArgs;
|
||||
found = try allocator.dupe(u8, name);
|
||||
}
|
||||
}
|
||||
|
||||
if (found) |s| return s;
|
||||
return error.FileNotFound;
|
||||
}
|
||||
200
cli/src/commands/queue/submit.zig
Normal file
200
cli/src/commands/queue/submit.zig
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
const std = @import("std");
|
||||
const ws = @import("../../net/ws/client.zig");
|
||||
const protocol = @import("../../net/protocol.zig");
|
||||
const crypto = @import("../../utils/crypto.zig");
|
||||
const Config = @import("../../config.zig").Config;
|
||||
const core = @import("../../core.zig");
|
||||
const history = @import("../../utils/history.zig");
|
||||
|
||||
/// Job submission configuration for one `queue submit` invocation.
pub const SubmitConfig = struct {
    /// Names of the jobs to submit; one submission per name.
    job_names: []const []const u8,
    /// Optional commit the jobs run against.
    commit_id: ?[]const u8,
    /// Scheduling priority.
    priority: u8,
    /// Optional snapshot identifier.
    snapshot_id: ?[]const u8,
    /// Optional SHA-256 of the snapshot.
    snapshot_sha256: ?[]const u8,
    /// Optional replacement for the default job arguments.
    args_override: ?[]const u8,
    /// Optional replacement for the default note.
    note_override: ?[]const u8,
    /// Requested CPU cores.
    cpu: u8,
    /// Requested memory (units defined by the server protocol — TODO confirm).
    memory: u8,
    /// Requested GPU count.
    gpu: u8,
    /// Optional GPU memory request, passed through as a string.
    gpu_memory: ?[]const u8,
    /// When true, print what would be submitted without connecting.
    dry_run: bool,
    /// Presumably bypasses server-side duplicate detection — confirm against protocol.
    force: bool,
    /// Extra arguments forwarded to the runner.
    runner_args: []const []const u8,

    /// Number of jobs this config will submit (one per job name).
    pub fn estimateTotalJobs(self: SubmitConfig) usize {
        return self.job_names.len;
    }
};
|
||||
|
||||
/// Aggregated outcome of a submit operation.
/// `errors` owns its message strings; release everything with `deinit`.
pub const SubmitResult = struct {
    success: bool,
    job_count: usize,
    errors: std.ArrayList([]const u8),

    /// Create an empty, successful result.
    /// Takes the allocator because call sites invoke `SubmitResult.init(allocator)`
    /// (the original signature took none and could not have compiled against
    /// its callers); the managed list remembers it for later appends.
    pub fn init(allocator: std.mem.Allocator) SubmitResult {
        return .{
            .success = true,
            .job_count = 0,
            .errors = std.ArrayList([]const u8).init(allocator),
        };
    }

    /// Record an error message, taking ownership of `msg`.
    /// Provided because `submitJobs` calls `result.addError(msg)`; the
    /// original struct had no such method. On append failure the message
    /// is freed rather than leaked.
    pub fn addError(self: *SubmitResult, msg: []const u8) void {
        self.errors.append(msg) catch {
            self.errors.allocator.free(msg);
        };
    }

    /// Free all recorded messages, then the list itself.
    /// `allocator` must be the one the messages were allocated with.
    pub fn deinit(self: *SubmitResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        // Managed list: its own deinit takes no allocator argument.
        self.errors.deinit();
    }
};
|
||||
|
||||
/// Submit jobs to the server.
/// In dry-run mode prints what would be submitted and returns without
/// connecting. Caller owns the returned SubmitResult (deinit with the
/// same allocator).
pub fn submitJobs(
    allocator: std.mem.Allocator,
    config: Config,
    submit_config: SubmitConfig,
    json: bool,
) !SubmitResult {
    var result = SubmitResult.init(allocator);
    errdefer result.deinit(allocator);

    // Dry run mode - just print what would be submitted.
    if (submit_config.dry_run) {
        if (json) {
            std.debug.print("{{\"success\":true,\"command\":\"queue.submit\",\"dry_run\":true,\"jobs\":[", .{});
            for (submit_config.job_names, 0..) |name, i| {
                if (i > 0) std.debug.print(",", .{});
                std.debug.print("\"{s}\"", .{name});
            }
            // BUG FIX: the original format string ended in `}}}}`, which a Zig
            // format string renders as `}}` — one brace too many (only the
            // object opened above is still open), producing invalid JSON.
            std.debug.print("],\"total\":{d}}}\n", .{submit_config.job_names.len});
        } else {
            std.debug.print("[DRY RUN] Would submit {d} jobs:\n", .{submit_config.job_names.len});
            for (submit_config.job_names) |name| {
                std.debug.print("  - {s}\n", .{name});
            }
        }
        result.job_count = submit_config.job_names.len;
        return result;
    }

    // Get WebSocket URL.
    const ws_url = try config.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    // Hash API key (forwarded to submitSingleJob, currently unused there).
    const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
    defer allocator.free(api_key_hash);

    // Connect to server; a connection failure is reported in the result
    // rather than propagated as an error.
    var client = ws.Client.connect(allocator, ws_url, config.api_key) catch |err| {
        const msg = try std.fmt.allocPrint(allocator, "Failed to connect: {}", .{err});
        result.addError(msg);
        result.success = false;
        return result;
    };
    defer client.close();

    // Submit each job; per-job failures are collected, not fatal.
    for (submit_config.job_names) |job_name| {
        submitSingleJob(
            allocator,
            &client,
            api_key_hash,
            job_name,
            submit_config,
            &result,
        ) catch |err| {
            const msg = try std.fmt.allocPrint(allocator, "Failed to submit {s}: {}", .{ job_name, err });
            result.addError(msg);
            result.success = false;
        };
    }

    // Save to history if successful; history failures are best-effort.
    if (result.success and result.job_count > 0) {
        if (submit_config.commit_id) |commit_id| {
            for (submit_config.job_names) |job_name| {
                history.saveEntry(allocator, job_name, commit_id) catch {};
            }
        }
    }

    return result;
}
|
||||
|
||||
/// Serialize one job as a JSON payload and send it over the open
/// connection. The third parameter (hashed API key) is currently unused.
/// Increments result.job_count after a successful send.
fn submitSingleJob(
    allocator: std.mem.Allocator,
    client: *ws.Client,
    _: []const u8,
    job_name: []const u8,
    submit_config: SubmitConfig,
    result: *SubmitResult,
) !void {
    // Assemble the submission payload.
    var body = std.ArrayList(u8).init(allocator);
    defer body.deinit();
    const w = body.writer();

    try w.print(
        "{{\"job_name\":\"{s}\",\"priority\":{d},\"resources\":{{\"cpu\":{d},\"memory\":{d},\"gpu\":{d}",
        .{ job_name, submit_config.priority, submit_config.cpu, submit_config.memory, submit_config.gpu },
    );
    if (submit_config.gpu_memory) |gm| {
        try w.print(",\"gpu_memory\":\"{s}\"", .{gm});
    }
    try w.print("}}", .{}); // close "resources"

    if (submit_config.commit_id) |cid| {
        try w.print(",\"commit_id\":\"{s}\"", .{cid});
    }
    if (submit_config.snapshot_id) |sid| {
        try w.print(",\"snapshot_id\":\"{s}\"", .{sid});
    }
    if (submit_config.note_override) |note| {
        try w.print(",\"note\":\"{s}\"", .{note});
    }
    try w.print("}}", .{}); // close the job object

    // `catch |err| return err` in the original is equivalent to `try`.
    try client.sendMessage(body.items);

    result.job_count += 1;
}
|
||||
|
||||
/// Render a submission result either as a JSON object or as plain text.
pub fn printResults(result: SubmitResult, json: bool) void {
    if (!json) {
        if (result.success) {
            std.debug.print("Successfully submitted {d} jobs\n", .{result.job_count});
        } else {
            std.debug.print("Failed to submit jobs ({d} errors)\n", .{result.errors.items.len});
            for (result.errors.items) |msg| {
                std.debug.print("  Error: {s}\n", .{msg});
            }
        }
        return;
    }

    const status = if (result.success) "true" else "false";
    std.debug.print("{{\"success\":{s},\"command\":\"queue.submit\",\"data\":{{\"submitted\":{d}", .{ status, result.job_count });

    if (result.errors.items.len > 0) {
        std.debug.print(",\"errors\":[", .{});
        for (result.errors.items, 0..) |msg, idx| {
            if (idx > 0) std.debug.print(",", .{});
            std.debug.print("\"{s}\"", .{msg});
        }
        std.debug.print("]", .{});
    }

    // Two closing braces: "data" object, then the outer object.
    std.debug.print("}}}}\n", .{});
}
|
||||
161
cli/src/commands/queue/validate.zig
Normal file
161
cli/src/commands/queue/validate.zig
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Validation errors for queue operations.
pub const ValidationError = error{
    // No job name was supplied.
    MissingJobName,
    // Commit ID is not a 40-character hex string.
    InvalidCommitId,
    // Snapshot ID has a bad length or illegal characters.
    InvalidSnapshotId,
    // cpu/memory/gpu fall outside the accepted ranges.
    InvalidResourceLimits,
    // The same job name appears more than once in a template.
    DuplicateJobName,
    // Priority outside the 1-10 range.
    InvalidPriority,
};
|
||||
|
||||
/// Validation result: collects human-readable error messages.
/// All stored messages are owned copies; release them with `deinit`.
pub const ValidationResult = struct {
    valid: bool,
    errors: std.ArrayList([]const u8),

    /// Start out valid with an empty (managed) error list.
    pub fn init(allocator: std.mem.Allocator) ValidationResult {
        return .{
            .valid = true,
            .errors = std.ArrayList([]const u8).init(allocator),
        };
    }

    /// Free every stored message, then the list.
    /// BUG FIX: the list is created managed (it carries its allocator), so
    /// its own deinit takes no argument; the original called the unmanaged
    /// form `deinit(allocator)`, which does not compile. The parameter is
    /// still needed to free the message copies.
    pub fn deinit(self: *ValidationResult, allocator: std.mem.Allocator) void {
        for (self.errors.items) |err| {
            allocator.free(err);
        }
        self.errors.deinit();
    }

    /// Mark the result invalid and append an owned copy of `msg`.
    /// Allocation failures are swallowed: the result still flips to
    /// invalid even when the message cannot be recorded.
    pub fn addError(self: *ValidationResult, allocator: std.mem.Allocator, msg: []const u8) void {
        self.valid = false;
        const copy = allocator.dupe(u8, msg) catch return;
        self.errors.append(copy) catch {
            allocator.free(copy);
        };
    }
};
|
||||
|
||||
/// A job name must be 1-128 bytes drawn from [A-Za-z0-9_.-].
pub fn validateJobName(name: []const u8) bool {
    if (name.len == 0 or name.len > 128) return false;
    for (name) |ch| {
        const allowed = std.ascii.isAlphanumeric(ch) or ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
|
||||
|
||||
/// A commit ID is exactly 40 hexadecimal characters.
pub fn validateCommitId(id: []const u8) bool {
    if (id.len != 40) return false;
    for (id) |ch| if (!std.ascii.isHex(ch)) return false;
    return true;
}
|
||||
|
||||
/// A snapshot ID must be 1-64 bytes drawn from [A-Za-z0-9_.-].
pub fn validateSnapshotId(id: []const u8) bool {
    if (id.len == 0 or id.len > 64) return false;
    for (id) |ch| {
        const allowed = std.ascii.isAlphanumeric(ch) or ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
|
||||
|
||||
/// Validate resource limits.
/// cpu: 1-128; memory: non-zero (see note); gpu: at most 16.
pub fn validateResources(cpu: u8, memory: u8, gpu: u8) ValidationError!void {
    if (cpu == 0 or cpu > 128) {
        return error.InvalidResourceLimits;
    }
    // BUG FIX: the original checked `memory > 1024`, but `memory` is a u8
    // (max 255) — 1024 is unrepresentable in the operand type, so the bound
    // could never trigger. Only the zero check is enforceable here; widen
    // the parameter type if a 1024 ceiling is actually required.
    if (memory == 0) {
        return error.InvalidResourceLimits;
    }
    if (gpu > 16) {
        return error.InvalidResourceLimits;
    }
}
|
||||
|
||||
/// Priority must lie in the inclusive range 1-10.
pub fn validatePriority(priority: u8) ValidationError!void {
    // For a u8, `priority < 1` is exactly `priority == 0`.
    if (priority == 0 or priority > 10) return error.InvalidPriority;
}
|
||||
|
||||
/// Run the full set of template checks and collect every failure.
/// Caller owns the returned ValidationResult.
pub fn validateJobTemplate(
    allocator: std.mem.Allocator,
    job_names: []const []const u8,
    commit_id: ?[]const u8,
    cpu: u8,
    memory: u8,
    gpu: u8,
) !ValidationResult {
    var result = ValidationResult.init(allocator);
    errdefer result.deinit(allocator);

    // An empty job list short-circuits: nothing else is worth checking.
    if (job_names.len == 0) {
        result.addError(allocator, "At least one job name is required");
        return result;
    }

    // Track names already seen so duplicates can be flagged.
    var seen = std.StringHashMap(void).init(allocator);
    defer seen.deinit();

    for (job_names) |name| {
        if (!validateJobName(name)) {
            const msg = try std.fmt.allocPrint(allocator, "Invalid job name: {s}", .{name});
            // addError stores its own copy, so this temporary is freed here.
            defer allocator.free(msg);
            result.addError(allocator, msg);
        }

        if (seen.contains(name)) {
            const msg = try std.fmt.allocPrint(allocator, "Duplicate job name: {s}", .{name});
            defer allocator.free(msg);
            result.addError(allocator, msg);
        } else {
            try seen.put(name, {});
        }
    }

    // Validate commit ID if provided.
    if (commit_id) |id| {
        if (!validateCommitId(id)) {
            result.addError(allocator, "Invalid commit ID format (expected 40 character hex)");
        }
    }

    // Validate resources; the specific error is collapsed into one message.
    validateResources(cpu, memory, gpu) catch {
        result.addError(allocator, "Invalid resource limits");
    };

    return result;
}
|
||||
|
||||
/// Report collected validation failures in JSON or plain-text form.
pub fn printValidationErrors(result: ValidationResult, json: bool) void {
    if (json) {
        std.debug.print("{{\"success\":false,\"command\":\"queue.validate\",\"errors\":[", .{});
        var first = true;
        for (result.errors.items) |msg| {
            if (!first) std.debug.print(",", .{});
            first = false;
            std.debug.print("\"{s}\"", .{msg});
        }
        std.debug.print("]}}\n", .{});
    } else {
        std.debug.print("Validation failed:\n", .{});
        for (result.errors.items) |msg| {
            std.debug.print("  - {s}\n", .{msg});
        }
    }
}
|
||||
|
|
@ -1,328 +0,0 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const manifest = @import("../utils/manifest.zig");
|
||||
const json = @import("../utils/json.zig");
|
||||
|
||||
/// Entry point for `ml requeue`: re-queues a prior job.
/// The target is either a commit ID (full 40-hex, or a >=7-char hex prefix
/// resolved under cfg.worker_base) or a run/task/path whose
/// run_manifest.json supplies the commit. The job is then queued over
/// WebSocket and the server response (including duplicate detection) is
/// reported.
pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void {
    if (argv.len == 0) {
        try printUsage();
        return error.InvalidArgs;
    }
    if (std.mem.eql(u8, argv[0], "--help") or std.mem.eql(u8, argv[0], "-h")) {
        try printUsage();
        return;
    }

    const target = argv[0];

    // Split args at "--": options before, verbatim job args after.
    var sep_index: ?usize = null;
    for (argv, 0..) |a, i| {
        if (std.mem.eql(u8, a, "--")) {
            sep_index = i;
            break;
        }
    }
    const pre = argv[1..(sep_index orelse argv.len)];
    const post = if (sep_index) |i| argv[(i + 1)..] else argv[0..0];

    const cfg = try Config.load(allocator);
    defer {
        // cfg is const; deinit needs a mutable copy.
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    // Defaults (overridable via flags below).
    var job_name_override: ?[]const u8 = null;
    var priority: u8 = cfg.default_priority;
    var cpu: u8 = cfg.default_cpu;
    var memory: u8 = cfg.default_memory;
    var gpu: u8 = cfg.default_gpu;
    var gpu_memory: ?[]const u8 = cfg.default_gpu_memory;
    var args_override: ?[]const u8 = null;
    var note_override: ?[]const u8 = null;
    var force: bool = false;

    // Option parsing over the pre-"--" arguments; each value flag
    // consumes the following argument.
    var i: usize = 0;
    while (i < pre.len) : (i += 1) {
        const a = pre[i];
        if (std.mem.eql(u8, a, "--name") and i + 1 < pre.len) {
            job_name_override = pre[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--priority") and i + 1 < pre.len) {
            priority = try std.fmt.parseInt(u8, pre[i + 1], 10);
            i += 1;
        } else if (std.mem.eql(u8, a, "--cpu") and i + 1 < pre.len) {
            cpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
            i += 1;
        } else if (std.mem.eql(u8, a, "--memory") and i + 1 < pre.len) {
            memory = try std.fmt.parseInt(u8, pre[i + 1], 10);
            i += 1;
        } else if (std.mem.eql(u8, a, "--gpu") and i + 1 < pre.len) {
            gpu = try std.fmt.parseInt(u8, pre[i + 1], 10);
            i += 1;
        } else if (std.mem.eql(u8, a, "--gpu-memory") and i + 1 < pre.len) {
            gpu_memory = pre[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--args") and i + 1 < pre.len) {
            args_override = pre[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--note") and i + 1 < pre.len) {
            note_override = pre[i + 1];
            i += 1;
        } else if (std.mem.eql(u8, a, "--force")) {
            force = true;
        } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) {
            try printUsage();
            return;
        } else {
            colors.printError("Unknown option: {s}\n", .{a});
            return error.InvalidArgs;
        }
    }

    // Join post-"--" args into one space-separated string.
    var args_joined: []const u8 = "";
    if (post.len > 0) {
        var buf: std.ArrayList(u8) = .{};
        defer buf.deinit(allocator);
        for (post, 0..) |a, idx| {
            if (idx > 0) try buf.append(allocator, ' ');
            try buf.appendSlice(allocator, a);
        }
        args_joined = try buf.toOwnedSlice(allocator);
    }
    // args_joined is only heap-allocated when post was non-empty.
    defer if (post.len > 0) allocator.free(args_joined);

    // --args wins over the joined trailing arguments.
    const args_final: []const u8 = if (args_override) |a| a else args_joined;
    const note_final: []const u8 = if (note_override) |n| n else "";

    // Target can be:
    // - commit_id (40-hex) or commit_id prefix (>=7 hex) resolvable under worker_base
    // - run_id/task_id/path (resolved to run_manifest.json to read commit_id)
    var commit_hex: []const u8 = "";
    var commit_hex_owned: ?[]u8 = null;
    defer if (commit_hex_owned) |s| allocator.free(s);

    var commit_bytes: []u8 = &[_]u8{};
    var commit_bytes_allocated = false;
    defer if (commit_bytes_allocated) allocator.free(commit_bytes);

    if (target.len >= 7 and target.len <= 40 and isHexLowerOrUpper(target)) {
        if (target.len == 40) {
            commit_hex = target;
        } else {
            commit_hex_owned = try resolveCommitPrefix(allocator, cfg.worker_base, target);
            commit_hex = commit_hex_owned.?;
        }

        const decoded = crypto.decodeHex(allocator, commit_hex) catch {
            // NOTE(review): nulling commit_hex_owned without freeing it
            // appears to leak the prefix-resolved string (the defer above
            // only frees a non-null value) — confirm and fix.
            commit_hex = "";
            commit_hex_owned = null;
            return error.InvalidCommitId;
        };
        if (decoded.len != 20) {
            // Fall through to manifest-based resolution below.
            // NOTE(review): same suspected leak of commit_hex_owned here.
            allocator.free(decoded);
            commit_hex = "";
            commit_hex_owned = null;
        } else {
            commit_bytes = decoded;
            commit_bytes_allocated = true;
        }
    }

    // Default job name unless overridden (or read from the manifest below).
    var job_name = blk: {
        if (job_name_override) |n| break :blk n;
        break :blk "requeue";
    };

    // Not a direct commit target: resolve via run_manifest.json.
    if (commit_hex.len == 0) {
        const manifest_path = try manifest.resolvePathWithBase(allocator, target, cfg.worker_base);
        defer allocator.free(manifest_path);

        const data = try manifest.readFileAlloc(allocator, manifest_path);
        defer allocator.free(data);

        const parsed = try std.json.parseFromSlice(std.json.Value, allocator, data, .{});
        defer parsed.deinit();

        if (parsed.value != .object) return error.InvalidManifest;
        const root = parsed.value.object;

        commit_hex = json.getString(root, "commit_id") orelse "";
        if (commit_hex.len != 40) {
            colors.printError("run manifest missing commit_id\n", .{});
            return error.InvalidManifest;
        }

        // Prefer the manifest's job name when no --name was given.
        if (job_name_override == null) {
            const j = json.getString(root, "job_name") orelse "";
            if (j.len > 0) job_name = j;
        }

        const b = try crypto.decodeHex(allocator, commit_hex);
        if (b.len != 20) {
            allocator.free(b);
            return error.InvalidCommitId;
        }
        commit_bytes = b;
        commit_bytes_allocated = true;
    }

    const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
    defer allocator.free(api_key_hash);

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
    defer client.close();

    // Two protocol variants: with and without a note payload.
    if (note_final.len > 0) {
        try client.sendQueueJobWithArgsNoteAndResources(
            job_name,
            commit_bytes,
            priority,
            api_key_hash,
            args_final,
            note_final,
            force,
            cpu,
            memory,
            gpu,
            gpu_memory,
        );
    } else {
        try client.sendQueueJobWithArgsAndResources(
            job_name,
            commit_bytes,
            priority,
            api_key_hash,
            args_final,
            force,
            cpu,
            memory,
            gpu,
            gpu_memory,
        );
    }

    // Receive response with duplicate detection.
    const message = try client.receiveMessage(allocator);
    defer allocator.free(message);

    // Non-packet responses that look like JSON get duplicate handling;
    // anything else is printed raw.
    const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
        if (message.len > 0 and message[0] == '{') {
            try handleDuplicateResponse(allocator, message, job_name, commit_hex);
        } else {
            colors.printInfo("Server response: {s}\n", .{message});
        }
        return;
    };
    defer packet.deinit(allocator);

    switch (packet.packet_type) {
        .success => {
            colors.printSuccess("Queued requeue\n", .{});
            colors.printInfo("Job: {s}\n", .{job_name});
            colors.printInfo("Commit: {s}\n", .{commit_hex});
        },
        .data => {
            if (packet.data_payload) |payload| {
                try handleDuplicateResponse(allocator, payload, job_name, commit_hex);
            }
        },
        .error_packet => {
            const err_msg = packet.error_message orelse "Unknown error";
            colors.printError("Error: {s}\n", .{err_msg});
            return error.ServerError;
        },
        else => {
            try client.handleResponsePacket(packet, "Requeue");
            colors.printSuccess("Queued requeue\n", .{});
            colors.printInfo("Job: {s}\n", .{job_name});
            colors.printInfo("Commit: {s}\n", .{commit_hex});
        },
    }
}
|
||||
|
||||
/// Interpret a JSON server response that may flag a duplicate job and print
/// the appropriate guidance. Unexpected shapes degrade gracefully (printed
/// raw or treated as "not a duplicate") instead of panicking.
fn handleDuplicateResponse(
    allocator: std.mem.Allocator,
    payload: []const u8,
    job_name: []const u8,
    commit_hex: []const u8,
) !void {
    const parsed = std.json.parseFromSlice(std.json.Value, allocator, payload, .{}) catch {
        colors.printInfo("Server response: {s}\n", .{payload});
        return;
    };
    defer parsed.deinit();

    // BUG FIX: the original indexed `.object` unconditionally and would
    // panic on any non-object JSON payload.
    if (parsed.value != .object) {
        colors.printInfo("Server response: {s}\n", .{payload});
        return;
    }
    const root = parsed.value.object;

    // BUG FIX: the original read `.?.bool` and crashed when "duplicate"
    // was present but not a boolean; also avoids the double lookup.
    const is_dup = if (root.get("duplicate")) |v| v == .bool and v.bool else false;
    if (!is_dup) {
        colors.printSuccess("Queued requeue\n", .{});
        colors.printInfo("Job: {s}\n", .{job_name});
        colors.printInfo("Commit: {s}\n", .{commit_hex});
        return;
    }

    // BUG FIX: missing or non-string fields previously hit `.?`/union
    // panics; now they simply end duplicate reporting.
    const existing_val = root.get("existing_id") orelse return;
    const status_val = root.get("status") orelse return;
    if (existing_val != .string or status_val != .string) return;
    const existing_id = existing_val.string;
    const status = status_val.string;

    // BUG FIX: guard the 8-char preview — `existing_id[0..8]` panicked on
    // IDs shorter than 8 bytes.
    const id_short = existing_id[0..@min(existing_id.len, 8)];

    if (std.mem.eql(u8, status, "queued") or std.mem.eql(u8, status, "running")) {
        colors.printInfo("\n→ Identical job already in progress: {s}\n", .{id_short});
        colors.printInfo("\n  Watch: ml watch {s}\n", .{id_short});
    } else if (std.mem.eql(u8, status, "completed")) {
        colors.printInfo("\n→ Identical job already completed: {s}\n", .{id_short});
        colors.printInfo("\n  Inspect: ml experiment show {s}\n", .{id_short});
        colors.printInfo("  Rerun: ml requeue {s} --force\n", .{commit_hex});
    } else if (std.mem.eql(u8, status, "failed")) {
        colors.printWarning("\n→ Identical job previously failed: {s}\n", .{id_short});
    }
}
|
||||
|
||||
/// Print the requeue command's usage synopsis.
fn printUsage() !void {
    colors.printInfo("Usage:\n", .{});
    colors.printInfo("  ml requeue <commit_id|run_id|task_id|path> [--name <job>] [--priority <n>] [--cpu <n>] [--memory <gb>] [--gpu <n>] [--gpu-memory <gb>] [--args <string>] [--note <string>] [--force] -- <args...>\n", .{});
}
|
||||
|
||||
/// True when every byte of `s` is a hexadecimal digit (either case).
/// An empty string vacuously passes.
fn isHexLowerOrUpper(s: []const u8) bool {
    for (s) |ch| if (!std.ascii.isHex(ch)) return false;
    return true;
}
|
||||
|
||||
/// Expand a commit-hash prefix to the single full 40-hex directory name
/// under `base_path`. Fails with InvalidCommitId when more than one
/// directory matches and FileNotFound when none does. Caller owns the
/// returned slice.
fn resolveCommitPrefix(allocator: std.mem.Allocator, base_path: []const u8, prefix: []const u8) ![]u8 {
    var dir = if (std.fs.path.isAbsolute(base_path))
        try std.fs.openDirAbsolute(base_path, .{ .iterate = true })
    else
        try std.fs.cwd().openDir(base_path, .{ .iterate = true });
    defer dir.close();

    var iter = dir.iterate();
    var match: ?[]u8 = null;
    errdefer if (match) |m| allocator.free(m);

    while (try iter.next()) |entry| {
        // Candidates are directories whose names are full 40-hex hashes
        // starting with the requested prefix.
        if (entry.kind != .directory) continue;
        const name = entry.name;
        if (name.len != 40) continue;
        if (!std.mem.startsWith(u8, name, prefix)) continue;
        if (!isHexLowerOrUpper(name)) continue;

        if (match != null) {
            colors.printError("Ambiguous commit prefix: {s}\n", .{prefix});
            return error.InvalidCommitId;
        }
        match = try allocator.dupe(u8, name);
    }

    if (match) |m| return m;
    colors.printError("No commit matches prefix: {s}\n", .{prefix});
    return error.FileNotFound;
}
|
||||
425
cli/src/commands/run.zig
Normal file
425
cli/src/commands/run.zig
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
const std = @import("std");
|
||||
const db = @import("../db.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const core = @import("../core.zig");
|
||||
const config = @import("../config.zig");
|
||||
|
||||
extern fn execvp(path: [*:0]const u8, argv: [*]const ?[*:0]const u8) c_int;
|
||||
extern fn waitpid(pid: c_int, status: *c_int, flags: c_int) c_int;
|
||||
|
||||
// Get current environment from libc
|
||||
extern var environ: [*]const ?[*:0]const u8;
|
||||
|
||||
// Inline equivalents of the POSIX wait-status macros (not available as
// extern symbols on macOS). Each mirrors the C macro's bit layout:
// the low 7 bits hold the terminating signal (0 = normal exit,
// 0x7F = stopped), the next 8 bits hold the exit status.
fn WIFEXITED(status: c_int) c_int {
    return @intFromBool((status & 0x7F) == 0);
}
fn WEXITSTATUS(status: c_int) c_int {
    const high = status >> 8;
    return high & 0xFF;
}
fn WIFSIGNALED(status: c_int) c_int {
    const sig = status & 0x7F;
    return @intFromBool(sig != 0x7F and sig != 0);
}
fn WTERMSIG(status: c_int) c_int {
    return status & 0x7F;
}
|
||||
const Manifest = manifest_lib.RunManifest;
|
||||
|
||||
/// Run command - always executes locally.
/// Usage:
///   ml run                      # Use entrypoint from config + args
///   ml run --lr 0.001           # Args appended to entrypoint
///   ml run -- python train.py   # Explicit command
///
/// Flow: resolve the command, create an artifact directory, record the run
/// in SQLite and a run_manifest.json (status RUNNING), fork+exec the
/// command capturing output, then finalize status/exit code in both the DB
/// and the manifest.
pub fn execute(allocator: std.mem.Allocator, args: []const []const u8) !void {
    var flags = core.flags.CommonFlags{};
    var command_args = try core.flags.parseCommon(allocator, args, &flags);
    defer command_args.deinit(allocator);

    core.output.init(if (flags.json) .json else .text);

    if (flags.help) {
        return printUsage();
    }

    const cfg = try config.Config.load(allocator);
    defer {
        // cfg is const; deinit needs a mutable copy.
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    // Parse command: entrypoint + args, or explicit -- command.
    const command = try resolveCommand(allocator, &cfg, command_args.items);
    defer freeCommand(allocator, command);

    // Generate run_id.
    const run_id = try db.generateUUID(allocator);
    defer allocator.free(run_id);

    // Determine experiment name.
    const experiment_name = if (cfg.experiment) |exp| exp.name else "default";

    // Build artifact path: <artifact_path>/<experiment>/<run_id>.
    const artifact_path = try std.fs.path.join(allocator, &[_][]const u8{
        cfg.artifact_path,
        experiment_name,
        run_id,
    });
    defer allocator.free(artifact_path);

    // Create run directory.
    // NOTE(review): makeDirAbsolute creates only the final path component;
    // this presumes <artifact_path>/<experiment> already exists — consider
    // makePath, confirm with callers/bootstrap.
    std.fs.makeDirAbsolute(artifact_path) catch |err| {
        if (err != error.PathAlreadyExists) {
            std.log.err("Failed to create run directory: {}", .{err});
            return error.MkdirFailed;
        }
    };

    // Get DB path and initialize if needed (lazy bootstrap).
    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);

    var database = try initOrOpenDB(allocator, db_path);
    defer database.close();

    // Write run manifest (status=RUNNING).
    const manifest_path = try std.fs.path.join(allocator, &[_][]const u8{
        artifact_path,
        "run_manifest.json",
    });
    defer allocator.free(manifest_path);

    const timestamp = try db.currentTimestamp(allocator);
    defer allocator.free(timestamp);

    // NOTE(review): several manifest fields below receive allocations
    // (join/dupe/duplicateStrings) and artifact_path aliases a slice freed
    // by the defer above; presumably Manifest/writeManifest own or copy
    // these before function exit — confirm against manifest_lib.
    var manifest = Manifest.init(allocator);
    manifest.run_id = run_id;
    manifest.experiment = experiment_name;
    manifest.command = try std.mem.join(allocator, " ", command);
    manifest.args = try duplicateStrings(allocator, command);
    manifest.started_at = try allocator.dupe(u8, timestamp);
    manifest.status = "RUNNING";
    manifest.artifact_path = artifact_path;
    manifest.synced = false;

    // Insert run into database.
    const run_name = try std.fmt.allocPrint(allocator, "run-{s}", .{run_id[0..8]});
    defer allocator.free(run_name);

    const sql = "INSERT INTO ml_runs (run_id, experiment_id, name, status, start_time, synced) VALUES (?, ?, ?, 'RUNNING', ?, 0);";
    const stmt = try database.prepare(sql);
    defer db.DB.finalize(stmt);

    try db.DB.bindText(stmt, 1, run_id);
    try db.DB.bindText(stmt, 2, experiment_name);
    try db.DB.bindText(stmt, 3, run_name);
    try db.DB.bindText(stmt, 4, timestamp);
    _ = try db.DB.step(stmt);

    // Write manifest.
    try manifest_lib.writeManifest(manifest, manifest_path, allocator);

    // Fork and execute, teeing child output into output.log.
    const output_log_path = try std.fs.path.join(allocator, &[_][]const u8{
        artifact_path,
        "output.log",
    });
    defer allocator.free(output_log_path);

    // Execute and capture.
    const exit_code = try executeAndCapture(
        allocator,
        command,
        output_log_path,
        &database,
        run_id,
    );

    // Update run status in database.
    const end_time = try db.currentTimestamp(allocator);
    defer allocator.free(end_time);

    const status = if (exit_code == 0) "FINISHED" else "FAILED";

    const update_sql = "UPDATE ml_runs SET status = ?, end_time = ?, exit_code = ?, pid = NULL WHERE run_id = ?;";
    const update_stmt = try database.prepare(update_sql);
    defer db.DB.finalize(update_stmt);

    try db.DB.bindText(update_stmt, 1, status);
    try db.DB.bindText(update_stmt, 2, end_time);
    try db.DB.bindInt64(update_stmt, 3, exit_code);
    try db.DB.bindText(update_stmt, 4, run_id);
    _ = try db.DB.step(update_stmt);

    // Update manifest to the terminal status.
    try manifest_lib.updateManifestStatus(manifest_path, status, exit_code, allocator);

    // Checkpoint WAL before reporting.
    database.checkpointOnExit();

    // Print result.
    if (flags.json) {
        std.debug.print("{{\"success\":true,\"run_id\":\"{s}\",\"status\":\"{s}\",\"exit_code\":{d}}}\n", .{
            run_id,
            status,
            exit_code,
        });
    } else {
        colors.printSuccess("✓ Run {s} complete ({s})\n", .{ run_id[0..8], status });
        if (cfg.sync_uri.len > 0) {
            colors.printInfo("↑ queued for sync\n", .{});
        }
    }
}
|
||||
|
||||
/// Resolve command from entrypoint + args, or explicit `--` command.
/// Always returns a fully-owned array: the outer slice AND every string
/// are allocated, so `freeCommand` can release each element safely.
/// Errors: NoCommand when `--` is given with nothing after it,
/// NoEntrypoint when neither `--` nor a configured entrypoint exists.
fn resolveCommand(allocator: std.mem.Allocator, cfg: *const config.Config, args: []const []const u8) ![][]const u8 {
    // Check for explicit -- separator.
    var double_dash_idx: ?usize = null;
    for (args, 0..) |arg, i| {
        if (std.mem.eql(u8, arg, "--")) {
            double_dash_idx = i;
            break;
        }
    }

    if (double_dash_idx) |idx| {
        // Explicit command after --
        if (idx + 1 >= args.len) {
            std.log.err("No command provided after --", .{});
            return error.NoCommand;
        }
        // BUG FIX: the original duplicated only the outer slice
        // (allocator.dupe([]const u8, ...)), leaving the strings borrowed
        // from argv; freeCommand then freed memory this function never
        // allocated. Deep-copy each argument so ownership is uniform with
        // the entrypoint branch below.
        const tail = args[idx + 1 ..];
        const owned = try allocator.alloc([]const u8, tail.len);
        var filled: usize = 0;
        errdefer {
            for (owned[0..filled]) |s| allocator.free(s);
            allocator.free(owned);
        }
        for (tail) |arg| {
            owned[filled] = try allocator.dupe(u8, arg);
            filled += 1;
        }
        return owned;
    }

    // Use entrypoint from config + args.
    if (cfg.experiment) |exp| {
        if (exp.entrypoint.len > 0) {
            var argv: std.ArrayList([]const u8) = .empty;

            // Parse entrypoint (split on spaces, skipping empty pieces).
            var iter = std.mem.splitScalar(u8, exp.entrypoint, ' ');
            while (iter.next()) |part| {
                if (part.len > 0) {
                    try argv.append(allocator, try allocator.dupe(u8, part));
                }
            }

            // Append args.
            for (args) |arg| {
                try argv.append(allocator, try allocator.dupe(u8, arg));
            }

            return try argv.toOwnedSlice(allocator);
        }
    }

    // No entrypoint configured and no explicit command.
    std.log.err("No entrypoint configured. Set entrypoint in .fetchml/config.toml or use: ml run -- <command>", .{});
    return error.NoEntrypoint;
}
|
||||
|
||||
/// Release a command array produced by resolveCommand: every element
/// string first, then the outer slice.
fn freeCommand(allocator: std.mem.Allocator, command: [][]const u8) void {
    for (command) |arg| allocator.free(arg);
    allocator.free(command);
}
|
||||
|
||||
/// Deep-copy an array of strings; the caller owns the outer slice and
/// every element.
fn duplicateStrings(allocator: std.mem.Allocator, strings: []const []const u8) ![][]const u8 {
    const out = try allocator.alloc([]const u8, strings.len);
    for (out, strings) |*slot, src| {
        slot.* = try allocator.dupe(u8, src);
    }
    return out;
}
|
||||
|
||||
/// Initialize or open database (lazy bootstrap).
/// Opens (creating if needed via db.DB.init) the SQLite file at `db_path`
/// and logs a first-use notice when the file did not previously exist.
fn initOrOpenDB(allocator: std.mem.Allocator, db_path: []const u8) !db.DB {
    // Existence probe: FileNotFound means "fresh install". Any other
    // access error falls through the catch block and is treated as
    // "exists" — the probe only gates the log line below, so that is
    // deliberate best-effort, not error swallowing on the open path.
    const db_exists = blk: {
        std.fs.accessAbsolute(db_path, .{}) catch |err| {
            if (err == error.FileNotFound) break :blk false;
        };
        break :blk true;
    };

    const database = try db.DB.init(allocator, db_path);

    if (!db_exists) {
        std.log.info("local mode active — tracking to {s}", .{db_path});
    }

    return database;
}
|
||||
|
||||
/// Execute command and capture output, parsing FETCHML_METRIC lines
///
/// Forks a child whose stdout and stderr are redirected into a pipe, tees
/// every byte read from the pipe into the file at `output_path`, and hands
/// each completed line to parseAndLogMetric so metric lines are recorded in
/// `database` under `run_id`. Also stores the child's PID in ml_runs.
/// Returns the child's exit code, 128 + signal number if it was killed by a
/// signal, or -1 if the wait status is neither.
fn executeAndCapture(
    allocator: std.mem.Allocator,
    command: []const []const u8,
    output_path: []const u8,
    database: *db.DB,
    run_id: []const u8,
) !i32 {
    // Create output file (truncates any existing file at output_path)
    var output_file = try std.fs.cwd().createFile(output_path, .{});
    defer output_file.close();

    // Create pipe for stdout
    const pipe = try std.posix.pipe();
    defer {
        // NOTE(review): the parent also closes pipe[1] explicitly below, so
        // this defer closes it a second time on return — confirm intended.
        std.posix.close(pipe[0]);
        std.posix.close(pipe[1]);
    }

    // Fork child process
    const pid = try std.posix.fork();

    if (pid == 0) {
        // Child process
        std.posix.close(pipe[0]); // Close read end

        // Redirect stdout to pipe (stderr too, so both streams are captured)
        _ = std.posix.dup2(pipe[1], std.posix.STDOUT_FILENO) catch std.process.exit(1);
        _ = std.posix.dup2(pipe[1], std.posix.STDERR_FILENO) catch std.process.exit(1);
        std.posix.close(pipe[1]);

        // Execute command using execvp (uses current environ)
        // NOTE(review): execvp expects a NULL-terminated array of C strings;
        // the @ptrCast of a Zig slice-of-slices only matches that layout if
        // `command` was built with C-compatible, null-terminated pointers —
        // confirm against the caller that assembles argv.
        const c_err = execvp(@ptrCast(command[0].ptr), @ptrCast(command.ptr));
        // execvp only returns on failure.
        std.log.err("Failed to execute {s}: {}", .{ command[0], c_err });
        std.process.exit(1);
        unreachable;
    }

    // Parent process
    std.posix.close(pipe[1]); // Close write end so read() sees EOF when the child exits

    // Store PID in database so other commands can inspect/signal the run
    const pid_sql = "UPDATE ml_runs SET pid = ? WHERE run_id = ?;";
    const pid_stmt = try database.prepare(pid_sql);
    defer db.DB.finalize(pid_stmt);
    try db.DB.bindInt64(pid_stmt, 1, pid);
    try db.DB.bindText(pid_stmt, 2, run_id);
    _ = try db.DB.step(pid_stmt);

    // Read from pipe and parse FETCHML_METRIC lines
    var buf: [4096]u8 = undefined; // raw read chunks from the pipe
    var line_buf: [1024]u8 = undefined; // current (possibly partial) line
    var line_len: usize = 0;

    while (true) {
        const bytes_read = std.posix.read(pipe[0], &buf) catch |err| {
            // Both branches break: any read error ends capture (best-effort).
            if (err == error.WouldBlock or err == error.BrokenPipe) break;
            break;
        };

        if (bytes_read == 0) break; // EOF: child closed its end of the pipe

        // Write to output file (verbatim tee of the child's combined output)
        try output_file.writeAll(buf[0..bytes_read]);

        // Parse lines
        for (buf[0..bytes_read]) |byte| {
            if (byte == '\n' or line_len >= line_buf.len - 1) {
                // Flush on newline, or early when the line buffer is nearly
                // full.
                // NOTE(review): on overflow the triggering byte is dropped
                // and the line is split — confirm acceptable (metric lines
                // are expected to be short).
                if (line_len > 0) {
                    line_buf[line_len] = 0; // NUL-terminate; slice below excludes it
                    const line = line_buf[0..line_len];
                    try parseAndLogMetric(allocator, line, database, run_id);
                    line_len = 0;
                }
            } else {
                line_buf[line_len] = byte;
                line_len += 1;
            }
        }
        // NOTE(review): a final line without a trailing newline is never
        // flushed to parseAndLogMetric after EOF — confirm intended.
    }

    // Wait for child (reaps the zombie and retrieves its wait status)
    var status: c_int = 0;
    _ = waitpid(@intCast(pid), &status, 0);

    // Parse exit code
    if (WIFEXITED(status) != 0) {
        return WEXITSTATUS(status);
    } else if (WIFSIGNALED(status) != 0) {
        // Shell convention for signal deaths: 128 + signal number.
        return 128 + WTERMSIG(status);
    }

    return -1; // neither exited nor signaled (e.g. stopped) — treat as failure
}
|
||||
|
||||
/// Parse FETCHML_METRIC line and log to database
/// Format: FETCHML_METRIC key=value [step=N]
///
/// Lines that do not start with the prefix, lack a key=value pair, carry an
/// invalid key (must match [a-zA-Z][a-zA-Z0-9_]*), or a non-numeric value
/// are silently ignored. Among trailing tokens, the last "step=N" wins;
/// negative or unparsable steps collapse to 0.
fn parseAndLogMetric(
    allocator: std.mem.Allocator,
    line: []const u8,
    database: *db.DB,
    run_id: []const u8,
) !void {
    _ = allocator; // kept for signature symmetry; nothing is allocated here

    const prefix = "FETCHML_METRIC";
    const trimmed = std.mem.trim(u8, line, " \t\r");
    if (!std.mem.startsWith(u8, trimmed, prefix)) return;

    // Payload is everything after the prefix, minus leading whitespace.
    const payload = std.mem.trimLeft(u8, trimmed[prefix.len..], " \t");

    // The first space-separated token must be the key=value pair.
    var tokens = std.mem.splitScalar(u8, payload, ' ');
    const pair = tokens.next() orelse return;

    var pair_parts = std.mem.splitScalar(u8, pair, '=');
    const key = pair_parts.next() orelse return;
    const value_text = pair_parts.next() orelse return;

    // Key must match [a-zA-Z][a-zA-Z0-9_]*
    if (key.len == 0 or !std.ascii.isAlphabetic(key[0])) return;
    for (key[1..]) |ch| {
        if (!(std.ascii.isAlphanumeric(ch) or ch == '_')) return;
    }

    const value = std.fmt.parseFloat(f64, value_text) catch return;

    // Scan the remaining tokens for "step=N" (last occurrence wins).
    var step: i64 = 0;
    while (tokens.next()) |tok| {
        if (std.mem.startsWith(u8, tok, "step=")) {
            const parsed = std.fmt.parseInt(i64, tok["step=".len..], 10) catch 0;
            step = if (parsed < 0) 0 else parsed;
        }
    }

    // Insert the metric row.
    const sql = "INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?);";
    const stmt = try database.prepare(sql);
    defer db.DB.finalize(stmt);

    try db.DB.bindText(stmt, 1, run_id);
    try db.DB.bindText(stmt, 2, key);
    try db.DB.bindDouble(stmt, 3, value);
    try db.DB.bindInt64(stmt, 4, step);
    _ = try db.DB.step(stmt);
}
|
||||
|
||||
/// Print CLI help for `ml run` to stderr.
fn printUsage() !void {
    const help_text =
        "Usage: ml run [options] [args...]\n" ++
        " ml run -- <command> [args...]\n\n" ++
        "Execute a run locally with experiment tracking.\n\n" ++
        "Options:\n" ++
        " --help, -h Show this help message\n" ++
        " --json Output structured JSON\n\n" ++
        "Examples:\n" ++
        " ml run # Use entrypoint from config\n" ++
        " ml run --lr 0.001 # Append args to entrypoint\n" ++
        " ml run -- python train.py # Run explicit command\n";
    std.debug.print("{s}", .{help_text});
}
|
||||
|
|
@ -4,10 +4,12 @@ const ws = @import("../net/ws/client.zig");
|
|||
const crypto = @import("../utils/crypto.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const auth = @import("../utils/auth.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub const StatusOptions = struct {
|
||||
json: bool = false,
|
||||
watch: bool = false,
|
||||
tui: bool = false,
|
||||
limit: ?usize = null,
|
||||
watch_interval: u32 = 5,
|
||||
};
|
||||
|
|
@ -22,17 +24,20 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
options.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--watch")) {
|
||||
options.watch = true;
|
||||
} else if (std.mem.eql(u8, arg, "--tui")) {
|
||||
options.tui = true;
|
||||
} else if (std.mem.eql(u8, arg, "--limit") and i + 1 < args.len) {
|
||||
options.limit = try std.fmt.parseInt(usize, args[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.startsWith(u8, arg, "--watch-interval=")) {
|
||||
options.watch_interval = try std.fmt.parseInt(u32, arg[17..], 10);
|
||||
} else if (std.mem.eql(u8, arg, "--help")) {
|
||||
try printUsage();
|
||||
return;
|
||||
return printUsage();
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (options.json) .json else .text);
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
|
|
@ -52,6 +57,8 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
|
||||
if (options.watch) {
|
||||
try runWatchMode(allocator, config, user_context, options);
|
||||
} else if (options.tui) {
|
||||
try runTuiMode(allocator, config, args);
|
||||
} else {
|
||||
try runSingleStatus(allocator, config, user_context, options);
|
||||
}
|
||||
|
|
@ -72,11 +79,11 @@ fn runSingleStatus(allocator: std.mem.Allocator, config: Config, user_context: a
|
|||
}
|
||||
|
||||
fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: auth.UserContext, options: StatusOptions) !void {
|
||||
colors.printInfo("Starting watch mode (interval: {d}s). Press Ctrl+C to stop.\n", .{options.watch_interval});
|
||||
core.output.info("Starting watch mode (interval: {d}s). Press Ctrl+C to stop.\n", .{options.watch_interval});
|
||||
|
||||
while (true) {
|
||||
if (!options.json) {
|
||||
colors.printInfo("\n=== FetchML Status - {s} ===\n", .{user_context.name});
|
||||
core.output.info("\n=== FetchML Status - {s} ===", .{user_context.name});
|
||||
}
|
||||
|
||||
try runSingleStatus(allocator, config, user_context, options);
|
||||
|
|
@ -89,11 +96,55 @@ fn runWatchMode(allocator: std.mem.Allocator, config: Config, user_context: auth
|
|||
}
|
||||
}
|
||||
|
||||
/// Launch the remote TUI monitor over SSH (server mode only).
///
/// Builds a remote shell command that cd's into the worker base directory,
/// exports the CLI configuration through FETCH_ML_CLI_* environment
/// variables, and runs `./bin/tui` with any extra `args` appended, then
/// executes it via `ssh` and waits for the session to end.
/// Returns error.ServerOnlyFeature in local mode.
fn runTuiMode(allocator: std.mem.Allocator, config: Config, args: []const []const u8) !void {
    if (config.isLocalMode()) {
        core.output.errorMsg("status", "TUI mode requires server mode. Use 'ml status' without --tui for local mode.");
        return error.ServerOnlyFeature;
    }

    std.debug.print("Launching TUI via SSH...\n", .{});

    // Build remote command that exports config via env vars and runs the TUI
    var remote_cmd_buffer: std.ArrayList(u8) = .empty;
    defer remote_cmd_buffer.deinit(allocator);
    {
        const writer = remote_cmd_buffer.writer(allocator);
        try writer.print("cd {s} && ", .{config.worker_base});
        try writer.print(
            "FETCH_ML_CLI_HOST=\"{s}\" FETCH_ML_CLI_USER=\"{s}\" FETCH_ML_CLI_BASE=\"{s}\" ",
            .{ config.worker_host, config.worker_user, config.worker_base },
        );
        try writer.print(
            "FETCH_ML_CLI_PORT=\"{d}\" FETCH_ML_CLI_API_KEY=\"{s}\" ",
            .{ config.worker_port, config.api_key },
        );
        try writer.writeAll("./bin/tui");
        for (args) |arg| {
            try writer.print(" {s}", .{arg});
        }
    }
    const remote_cmd = try remote_cmd_buffer.toOwnedSlice(allocator);
    defer allocator.free(remote_cmd);

    // ssh takes a single "[user@]host" destination argument. Passing the
    // user and host as two separate argv entries made ssh treat the user as
    // the destination and the host as part of the remote command.
    const destination = try std.fmt.allocPrint(allocator, "{s}@{s}", .{ config.worker_user, config.worker_host });
    defer allocator.free(destination);

    // NOTE(review): worker_port is exported to the remote environment but
    // not passed to ssh itself (-p) — confirm sshd listens on the default port.
    const ssh_args = &[_][]const u8{
        "ssh",
        destination,
        remote_cmd,
    };

    var child = std.process.Child.init(ssh_args, allocator);
    try child.spawn();
    _ = try child.wait();
}
|
||||
|
||||
fn printUsage() !void {
|
||||
colors.printInfo("Usage: ml status [options]\n", .{});
|
||||
colors.printInfo("\nOptions:\n", .{});
|
||||
colors.printInfo(" --json Output structured JSON\n", .{});
|
||||
colors.printInfo(" --watch Watch mode - continuously update status\n", .{});
|
||||
colors.printInfo(" --tui Launch TUI monitor via SSH\n", .{});
|
||||
colors.printInfo(" --limit <count> Limit number of results shown\n", .{});
|
||||
colors.printInfo(" --watch-interval=<s> Set watch interval in seconds (default: 5)\n", .{});
|
||||
colors.printInfo(" --help Show this help message\n", .{});
|
||||
|
|
|
|||
|
|
@ -1,177 +1,285 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const rsync = @import("../utils/rsync_embedded.zig");
|
||||
const config = @import("../config.zig");
|
||||
const db = @import("../db.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const logging = @import("../utils/logging.zig");
|
||||
const json = @import("../utils/json.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const core = @import("../core.zig");
|
||||
const manifest_lib = @import("../manifest.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var specific_run_id: ?[]const u8 = null;
|
||||
|
||||
// Global flags
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
} else if (!std.mem.startsWith(u8, arg, "--")) {
|
||||
specific_run_id = arg;
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml sync requires server connection\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
|
||||
const db_path = try cfg.getDBPath(allocator);
|
||||
defer allocator.free(db_path);
|
||||
|
||||
var database = try db.DB.init(allocator, db_path);
|
||||
defer database.close();
|
||||
|
||||
var runs_to_sync: std.ArrayList(RunInfo) = .empty;
|
||||
defer {
|
||||
for (runs_to_sync.items) |*r| r.deinit(allocator);
|
||||
runs_to_sync.deinit(allocator);
|
||||
}
|
||||
|
||||
if (specific_run_id) |run_id| {
|
||||
const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE run_id = ? AND synced = 0;";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
try db.DB.bindText(stmt, 1, run_id);
|
||||
if (try db.DB.step(stmt)) {
|
||||
try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator));
|
||||
} else {
|
||||
colors.printWarning("Run {s} already synced or not found\n", .{run_id});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const path = args[0];
|
||||
var job_name: ?[]const u8 = null;
|
||||
var should_queue = false;
|
||||
var priority: u8 = 5;
|
||||
var json_mode: bool = false;
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
|
||||
job_name = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--queue")) {
|
||||
should_queue = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--json")) {
|
||||
json_mode = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
|
||||
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
}
|
||||
|
||||
// Calculate commit ID (SHA256 of directory tree)
|
||||
const commit_id = try crypto.hashDirectory(allocator, path);
|
||||
defer allocator.free(commit_id);
|
||||
|
||||
// Construct remote destination path
|
||||
const remote_path = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/",
|
||||
.{ config.api_key, config.worker_host, config.worker_base, commit_id },
|
||||
);
|
||||
defer allocator.free(remote_path);
|
||||
|
||||
// Sync using embedded rsync (no external binary needed)
|
||||
try rsync.sync(allocator, path, remote_path, config.worker_port);
|
||||
|
||||
if (json_mode) {
|
||||
std.debug.print("{\"ok\":true,\"action\":\"sync\",\"commit_id\":\"{s}\"}\n", .{commit_id});
|
||||
} else {
|
||||
colors.printSuccess("✓ Files synced to server\n", .{});
|
||||
}
|
||||
|
||||
// If queue flag is set, queue the job
|
||||
if (should_queue) {
|
||||
const queue_cmd = @import("queue.zig");
|
||||
const actual_job_name = job_name orelse commit_id[0..8];
|
||||
const queue_args = [_][]const u8{ actual_job_name, "--commit", commit_id, "--priority", try std.fmt.allocPrint(allocator, "{d}", .{priority}) };
|
||||
defer allocator.free(queue_args[queue_args.len - 1]);
|
||||
try queue_cmd.run(allocator, &queue_args);
|
||||
}
|
||||
|
||||
// Optional: Connect to server for progress monitoring if --monitor flag is used
|
||||
var monitor_progress = false;
|
||||
for (args[1..]) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--monitor")) {
|
||||
monitor_progress = true;
|
||||
break;
|
||||
const sql = "SELECT run_id, experiment_id, name, status, start_time, end_time FROM ml_runs WHERE synced = 0;";
|
||||
const stmt = try database.prepare(sql);
|
||||
defer db.DB.finalize(stmt);
|
||||
while (try db.DB.step(stmt)) {
|
||||
try runs_to_sync.append(allocator, try RunInfo.fromStmt(stmt, allocator));
|
||||
}
|
||||
}
|
||||
|
||||
if (monitor_progress) {
|
||||
std.debug.print("\nMonitoring sync progress...\n", .{});
|
||||
try monitorSyncProgress(allocator, &config, commit_id);
|
||||
if (runs_to_sync.items.len == 0) {
|
||||
if (!flags.json) colors.printSuccess("All runs already synced!\n", .{});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
fn printUsage() void {
|
||||
logging.err("Usage: ml sync <path> [options]\n\n", .{});
|
||||
logging.err("Options:\n", .{});
|
||||
logging.err(" --name <job> Override job name when used with --queue\n", .{});
|
||||
logging.err(" --queue Queue the job after syncing\n", .{});
|
||||
logging.err(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
|
||||
logging.err(" --monitor Wait and show basic sync progress\n", .{});
|
||||
logging.err(" --json Output machine-readable JSON (sync result only)\n", .{});
|
||||
logging.err(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
fn monitorSyncProgress(allocator: std.mem.Allocator, config: *const Config, commit_id: []const u8) !void {
|
||||
_ = commit_id;
|
||||
// Use plain password for WebSocket authentication
|
||||
const api_key_plain = config.api_key;
|
||||
|
||||
// Connect to server with retry logic
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
logging.info("Connecting to server {s}...\n", .{ws_url});
|
||||
var client = try ws.Client.connectWithRetry(allocator, ws_url, api_key_plain, 3);
|
||||
defer client.disconnect();
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
// Send progress monitoring request (this would be a new opcode on the server side)
|
||||
// For now, we'll just listen for any progress messages
|
||||
|
||||
var timeout_counter: u32 = 0;
|
||||
const max_timeout = 30; // 30 seconds timeout
|
||||
var spinner_index: usize = 0;
|
||||
const spinner_chars = [_]u8{ '|', '/', '-', '\\' };
|
||||
|
||||
while (timeout_counter < max_timeout) {
|
||||
const message = client.receiveMessage(allocator) catch |err| {
|
||||
switch (err) {
|
||||
error.ConnectionClosed, error.ConnectionTimedOut => {
|
||||
timeout_counter += 1;
|
||||
spinner_index = (spinner_index + 1) % 4;
|
||||
logging.progress("Waiting for progress {c} (attempt {d}/{d})\n", .{ spinner_chars[spinner_index], timeout_counter, max_timeout });
|
||||
std.Thread.sleep(1 * std.time.ns_per_s);
|
||||
continue;
|
||||
},
|
||||
else => return err,
|
||||
}
|
||||
var success_count: usize = 0;
|
||||
for (runs_to_sync.items) |run_info| {
|
||||
if (!flags.json) colors.printInfo("Syncing run {s}...\n", .{run_info.run_id[0..8]});
|
||||
syncRun(allocator, &database, &client, run_info, api_key_hash) catch |err| {
|
||||
if (!flags.json) colors.printError("Failed to sync run {s}: {}\n", .{ run_info.run_id[0..8], err });
|
||||
continue;
|
||||
};
|
||||
defer allocator.free(message);
|
||||
|
||||
// Parse JSON progress message using shared utilities
|
||||
const parsed = std.json.parseFromSlice(std.json.Value, allocator, message, .{}) catch {
|
||||
logging.success("Sync progress: {s}\n", .{message});
|
||||
break;
|
||||
};
|
||||
defer parsed.deinit();
|
||||
|
||||
if (parsed.value == .object) {
|
||||
const root = parsed.value.object;
|
||||
const status = json.getString(root, "status") orelse "unknown";
|
||||
const progress = json.getInt(root, "progress") orelse 0;
|
||||
const total = json.getInt(root, "total") orelse 0;
|
||||
|
||||
if (std.mem.eql(u8, status, "complete")) {
|
||||
logging.success("Sync complete!\n", .{});
|
||||
break;
|
||||
} else if (std.mem.eql(u8, status, "error")) {
|
||||
const error_msg = json.getString(root, "error") orelse "Unknown error";
|
||||
logging.err("Sync failed: {s}\n", .{error_msg});
|
||||
return error.SyncFailed;
|
||||
} else {
|
||||
const pct = if (total > 0) @divTrunc(progress * 100, total) else 0;
|
||||
logging.progress("Sync: {s} ({d}/{d} files, {d}%)\n", .{ status, progress, total, pct });
|
||||
}
|
||||
} else {
|
||||
logging.success("Sync progress: {s}\n", .{message});
|
||||
break;
|
||||
}
|
||||
const update_sql = "UPDATE ml_runs SET synced = 1 WHERE run_id = ?;";
|
||||
const update_stmt = try database.prepare(update_sql);
|
||||
defer db.DB.finalize(update_stmt);
|
||||
try db.DB.bindText(update_stmt, 1, run_info.run_id);
|
||||
_ = try db.DB.step(update_stmt);
|
||||
success_count += 1;
|
||||
}
|
||||
|
||||
if (timeout_counter >= max_timeout) {
|
||||
std.debug.print("Progress monitoring timed out. Sync may still be running.\n", .{});
|
||||
database.checkpointOnExit();
|
||||
|
||||
if (flags.json) {
|
||||
std.debug.print("{{\"success\":true,\"synced\":{d},\"total\":{d}}}\n", .{ success_count, runs_to_sync.items.len });
|
||||
} else {
|
||||
colors.printSuccess("Synced {d}/{d} runs\n", .{ success_count, runs_to_sync.items.len });
|
||||
}
|
||||
}
|
||||
|
||||
/// One row of ml_runs pending upload to the server.
/// All string fields are owned copies made by the allocator passed to
/// fromStmt; release them with deinit using the same allocator.
const RunInfo = struct {
    run_id: []const u8,
    experiment_id: []const u8,
    name: []const u8,
    status: []const u8,
    start_time: []const u8,
    // null when the run has not finished (empty end_time column).
    end_time: ?[]const u8,

    /// Copy the current row of `stmt` into an owned RunInfo.
    /// Column order must match the SELECT that produced the statement:
    /// run_id, experiment_id, name, status, start_time, end_time.
    /// Unwraps `stmt.?`, so the statement must be non-null.
    /// NOTE(review): if a later dupe fails, the copies already made here
    /// leak — confirm acceptable for this CLI's error paths.
    fn fromStmt(stmt: db.Stmt, allocator: std.mem.Allocator) !RunInfo {
        const s = stmt.?;
        return RunInfo{
            .run_id = try allocator.dupe(u8, db.DB.columnText(s, 0)),
            .experiment_id = try allocator.dupe(u8, db.DB.columnText(s, 1)),
            .name = try allocator.dupe(u8, db.DB.columnText(s, 2)),
            .status = try allocator.dupe(u8, db.DB.columnText(s, 3)),
            .start_time = try allocator.dupe(u8, db.DB.columnText(s, 4)),
            // An empty end_time column is mapped to null ("still running").
            .end_time = if (db.DB.columnText(s, 5).len > 0) try allocator.dupe(u8, db.DB.columnText(s, 5)) else null,
        };
    }

    /// Free every owned string. Must be called with the same allocator
    /// that was passed to fromStmt.
    fn deinit(self: *RunInfo, allocator: std.mem.Allocator) void {
        allocator.free(self.run_id);
        allocator.free(self.experiment_id);
        allocator.free(self.name);
        allocator.free(self.status);
        allocator.free(self.start_time);
        if (self.end_time) |et| allocator.free(et);
    }
};
|
||||
|
||||
/// Upload one local run (metadata, metrics, params) to the server over the
/// already-connected WebSocket `client`, authenticating with `api_key_hash`.
///
/// Reads the run's metrics and params from `database`, serializes everything
/// into a single JSON object, sends it as a sync_run message, then waits for
/// one response; any response not containing "sync_ack" is treated as a
/// rejection (error.SyncRejected).
fn syncRun(
    allocator: std.mem.Allocator,
    database: *db.DB,
    client: *ws.Client,
    run_info: RunInfo,
    api_key_hash: []const u8,
) !void {
    // Get metrics for this run
    var metrics: std.ArrayList(Metric) = .empty;
    defer {
        for (metrics.items) |*m| m.deinit(allocator);
        metrics.deinit(allocator);
    }

    const metrics_sql = "SELECT key, value, step FROM ml_metrics WHERE run_id = ?;";
    const metrics_stmt = try database.prepare(metrics_sql);
    defer db.DB.finalize(metrics_stmt);
    try db.DB.bindText(metrics_stmt, 1, run_info.run_id);

    while (try db.DB.step(metrics_stmt)) {
        try metrics.append(allocator, Metric{
            .key = try allocator.dupe(u8, db.DB.columnText(metrics_stmt, 0)),
            .value = db.DB.columnDouble(metrics_stmt, 1),
            .step = db.DB.columnInt64(metrics_stmt, 2),
        });
    }

    // Get params for this run
    var params: std.ArrayList(Param) = .empty;
    defer {
        for (params.items) |*p| p.deinit(allocator);
        params.deinit(allocator);
    }

    const params_sql = "SELECT key, value FROM ml_params WHERE run_id = ?;";
    const params_stmt = try database.prepare(params_sql);
    defer db.DB.finalize(params_stmt);
    try db.DB.bindText(params_stmt, 1, run_info.run_id);

    while (try db.DB.step(params_stmt)) {
        try params.append(allocator, Param{
            .key = try allocator.dupe(u8, db.DB.columnText(params_stmt, 0)),
            .value = try allocator.dupe(u8, db.DB.columnText(params_stmt, 1)),
        });
    }

    // Build sync JSON by hand.
    // NOTE(review): string fields (name, keys, param values, timestamps) are
    // interpolated without JSON escaping; a quote or backslash in any of
    // them produces invalid JSON — confirm inputs are constrained upstream,
    // or switch to std.json serialization.
    var sync_json: std.ArrayList(u8) = .empty;
    defer sync_json.deinit(allocator);
    const writer = sync_json.writer(allocator);

    try writer.writeAll("{");
    try writer.print("\"run_id\":\"{s}\",", .{run_info.run_id});
    try writer.print("\"experiment_id\":\"{s}\",", .{run_info.experiment_id});
    try writer.print("\"name\":\"{s}\",", .{run_info.name});
    try writer.print("\"status\":\"{s}\",", .{run_info.status});
    try writer.print("\"start_time\":\"{s}\",", .{run_info.start_time});
    if (run_info.end_time) |et| {
        try writer.print("\"end_time\":\"{s}\",", .{et});
    } else {
        // Unfinished run: explicit JSON null.
        try writer.writeAll("\"end_time\":null,");
    }

    // Add metrics
    try writer.writeAll("\"metrics\":[");
    for (metrics.items, 0..) |m, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.print("{{\"key\":\"{s}\",\"value\":{d},\"step\":{d}}}", .{ m.key, m.value, m.step });
    }
    try writer.writeAll("],");

    // Add params
    try writer.writeAll("\"params\":[");
    for (params.items, 0..) |p, i| {
        if (i > 0) try writer.writeAll(",");
        try writer.print("{{\"key\":\"{s}\",\"value\":\"{s}\"}}", .{ p.key, p.value });
    }
    try writer.writeAll("]}");

    // Send sync_run message
    try client.sendSyncRun(sync_json.items, api_key_hash);

    // Wait for sync_ack
    const response = try client.receiveMessage(allocator);
    defer allocator.free(response);

    // NOTE(review): substring match, not structured parsing — a payload that
    // merely mentions "sync_ack" would be accepted; confirm protocol shape.
    if (std.mem.indexOf(u8, response, "sync_ack") == null) {
        return error.SyncRejected;
    }
}
|
||||
|
||||
/// One metric sample read back from ml_metrics for upload.
/// `key` is an owned copy; value and step are plain scalars.
const Metric = struct {
    key: []const u8,
    value: f64,
    step: i64,

    /// Free the owned key with the allocator that duplicated it.
    fn deinit(self: *Metric, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
    }
};
|
||||
|
||||
/// One run parameter read back from ml_params for upload.
/// Both fields are owned copies.
const Param = struct {
    key: []const u8,
    value: []const u8,

    /// Free both owned strings with the allocator that duplicated them.
    fn deinit(self: *Param, allocator: std.mem.Allocator) void {
        allocator.free(self.key);
        allocator.free(self.value);
    }
};
|
||||
|
||||
/// Print CLI help for `ml sync` to stderr.
fn printUsage() void {
    const help_text =
        "Usage: ml sync [run_id] [options]\n\n" ++
        "Push local experiment runs to the server.\n\n" ++
        "Options:\n" ++
        " --json Output structured JSON\n" ++
        " --help, -h Show this help message\n\n" ++
        "Examples:\n" ++
        " ml sync # Sync all unsynced runs\n" ++
        " ml sync abc123 # Sync specific run\n";
    std.debug.print("{s}", .{help_text});
}
|
||||
|
||||
/// Find the git root directory by walking up from `start_path`.
///
/// Resolves `start_path` to an absolute path, then checks each ancestor
/// directory for a `.git` entry. Returns an allocator-owned copy of the
/// first directory containing one, or null when the filesystem root is
/// reached without finding it. The caller owns (and must free) the result.
fn findGitRoot(allocator: std.mem.Allocator, start_path: []const u8) !?[]const u8 {
    var buf: [std.fs.max_path_bytes]u8 = undefined;
    const path = try std.fs.realpath(start_path, &buf);

    // Must be explicitly []const u8: std.fs.path.dirname returns ?[]const u8,
    // so an inferred []u8 (realpath's return type) would reject the
    // `current = parent` reassignment below.
    var current: []const u8 = path;
    while (true) {
        // Check if .git exists in the current directory.
        const git_path = try std.fs.path.join(allocator, &[_][]const u8{ current, ".git" });
        defer allocator.free(git_path);

        if (std.fs.accessAbsolute(git_path, .{})) {
            // Found a .git entry — this directory is the repository root.
            return try allocator.dupe(u8, current);
        } else |_| {
            // .git not found here (or not accessible); try the parent.
            const parent = std.fs.path.dirname(current) orelse return null;
            // Guard against dirname yielding the same path (filesystem root).
            if (std.mem.eql(u8, parent, current)) return null;
            current = parent;
        }
    }
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const protocol = @import("../net/protocol.zig");
|
|||
const colors = @import("../utils/colors.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const io = @import("../utils/io.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
pub const Options = struct {
|
||||
json: bool = false,
|
||||
|
|
@ -14,36 +15,32 @@ pub const Options = struct {
|
|||
};
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
if (args.len == 0) {
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
|
||||
var opts = Options{};
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var commit_hex: ?[]const u8 = null;
|
||||
var task_id: ?[]const u8 = null;
|
||||
|
||||
var i: usize = 0;
|
||||
while (i < args.len) : (i += 1) {
|
||||
const arg = args[i];
|
||||
if (std.mem.eql(u8, arg, "--json")) {
|
||||
opts.json = true;
|
||||
flags.json = true;
|
||||
} else if (std.mem.eql(u8, arg, "--verbose")) {
|
||||
opts.verbose = true;
|
||||
flags.verbose = true;
|
||||
} else if (std.mem.eql(u8, arg, "--task") and i + 1 < args.len) {
|
||||
opts.task_id = args[i + 1];
|
||||
task_id = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.startsWith(u8, arg, "--help")) {
|
||||
try printUsage();
|
||||
return;
|
||||
return printUsage();
|
||||
} else if (std.mem.startsWith(u8, arg, "--")) {
|
||||
colors.printError("Unknown option: {s}\n", .{arg});
|
||||
try printUsage();
|
||||
core.output.errorMsg("validate", "Unknown option");
|
||||
return error.InvalidArgs;
|
||||
} else {
|
||||
commit_hex = arg;
|
||||
}
|
||||
}
|
||||
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
|
|
@ -61,10 +58,13 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
if (opts.task_id) |tid| {
|
||||
if (task_id) |tid| {
|
||||
try client.sendValidateRequestTask(api_key_hash, tid);
|
||||
} else {
|
||||
if (commit_hex == null or commit_hex.?.len != 40) {
|
||||
if (commit_hex == null) {
|
||||
core.output.errorMsg("validate", "No commit hash specified");
|
||||
return printUsage();
|
||||
} else if (commit_hex.?.len != 40) {
|
||||
colors.printError("validate requires a 40-char commit id (or --task <task_id>)\n", .{});
|
||||
try printUsage();
|
||||
return error.InvalidArgs;
|
||||
|
|
@ -80,7 +80,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
defer allocator.free(msg);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch {
|
||||
if (opts.json) {
|
||||
if (flags.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{msg});
|
||||
} else {
|
||||
|
|
@ -101,7 +101,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
}
|
||||
|
||||
const payload = packet.data_payload.?;
|
||||
if (opts.json) {
|
||||
if (flags.json) {
|
||||
var out = io.stdoutWriter();
|
||||
try out.print("{s}\n", .{payload});
|
||||
} else {
|
||||
|
|
@ -109,7 +109,7 @@ pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
|||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value.object;
|
||||
const ok = try printHumanReport(root, opts.verbose);
|
||||
const ok = try printHumanReport(root, flags.verbose);
|
||||
if (!ok) return error.ValidationFailed;
|
||||
}
|
||||
}
|
||||
|
|
@ -224,8 +224,8 @@ fn printUsage() !void {
|
|||
|
||||
test "validate human report formatting" {
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
const allocator = gpa.allocator();
|
||||
defer _ = gpa.deinit();
|
||||
|
||||
const payload =
|
||||
\\{
|
||||
|
|
@ -245,8 +245,8 @@ test "validate human report formatting" {
|
|||
const parsed = try std.json.parseFromSlice(std.json.Value, allocator, payload, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
var buf = std.ArrayList(u8).init(allocator);
|
||||
defer buf.deinit();
|
||||
var buf = std.ArrayList(u8).empty;
|
||||
defer buf.deinit(allocator);
|
||||
|
||||
_ = try printHumanReport(buf.writer(), parsed.value.object, false);
|
||||
try testing.expect(std.mem.indexOf(u8, buf.items, "failed_checks") != null);
|
||||
|
|
|
|||
|
|
@ -1,110 +1,83 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("../config.zig").Config;
|
||||
const config = @import("../config.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const rsync = @import("../utils/rsync_embedded.zig");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const core = @import("../core.zig");
|
||||
const mode = @import("../mode.zig");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
|
||||
pub fn run(allocator: std.mem.Allocator, args: []const []const u8) !void {
|
||||
var flags = core.flags.CommonFlags{};
|
||||
var should_sync = false;
|
||||
const sync_interval: u64 = 30; // Default 30 seconds
|
||||
|
||||
if (args.len == 0) {
|
||||
printUsage();
|
||||
return error.InvalidArgs;
|
||||
return printUsage();
|
||||
}
|
||||
|
||||
// Global flags
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
printUsage();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const path = args[0];
|
||||
var job_name: ?[]const u8 = null;
|
||||
var priority: u8 = 5;
|
||||
var should_queue = false;
|
||||
var json: bool = false;
|
||||
|
||||
// Parse flags
|
||||
var i: usize = 1;
|
||||
while (i < args.len) : (i += 1) {
|
||||
if (std.mem.eql(u8, args[i], "--name") and i + 1 < args.len) {
|
||||
job_name = args[i + 1];
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--priority") and i + 1 < args.len) {
|
||||
priority = try std.fmt.parseInt(u8, args[i + 1], 10);
|
||||
i += 1;
|
||||
} else if (std.mem.eql(u8, args[i], "--queue")) {
|
||||
should_queue = true;
|
||||
} else if (std.mem.eql(u8, args[i], "--json")) {
|
||||
json = true;
|
||||
for (args) |arg| {
|
||||
if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
|
||||
return printUsage();
|
||||
} else if (std.mem.eql(u8, arg, "--sync")) {
|
||||
should_sync = true;
|
||||
} else if (std.mem.eql(u8, arg, "--json")) {
|
||||
flags.json = true;
|
||||
}
|
||||
}
|
||||
|
||||
const config = try Config.load(allocator);
|
||||
core.output.init(if (flags.json) .json else .text);
|
||||
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_config = config;
|
||||
mut_config.deinit(allocator);
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
if (json) {
|
||||
std.debug.print("{\"ok\":true,\"action\":\"watch\",\"path\":\"{s}\",\"queued\":{s}}\n", .{ path, if (should_queue) "true" else "false" });
|
||||
} else {
|
||||
std.debug.print("Watching {s} for changes...\n", .{path});
|
||||
std.debug.print("Press Ctrl+C to stop\n", .{});
|
||||
}
|
||||
|
||||
// Initial sync
|
||||
var last_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
|
||||
defer allocator.free(last_commit_id);
|
||||
|
||||
// Watch for changes
|
||||
var watcher = try std.fs.cwd().openDir(path, .{ .iterate = true });
|
||||
defer watcher.close();
|
||||
|
||||
var last_modified: u64 = 0;
|
||||
|
||||
while (true) {
|
||||
// Check for file changes
|
||||
var modified = false;
|
||||
var walker = try watcher.walk(allocator);
|
||||
defer walker.deinit();
|
||||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind == .file) {
|
||||
const file = try watcher.openFile(entry.path, .{});
|
||||
defer file.close();
|
||||
|
||||
const stat = try file.stat();
|
||||
if (stat.mtime > last_modified) {
|
||||
last_modified = @intCast(stat.mtime);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
// Check mode if syncing
|
||||
if (should_sync) {
|
||||
const mode_result = try mode.detect(allocator, cfg);
|
||||
if (mode.isOffline(mode_result.mode)) {
|
||||
colors.printError("ml watch --sync requires server connection\n", .{});
|
||||
return error.RequiresServer;
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
if (!json) {
|
||||
std.debug.print("\nChanges detected, syncing...\n", .{});
|
||||
}
|
||||
if (flags.json) {
|
||||
std.debug.print("{{\"ok\":true,\"action\":\"watch\",\"sync\":{s}}}\n", .{if (should_sync) "true" else "false"});
|
||||
} else {
|
||||
if (should_sync) {
|
||||
colors.printInfo("Watching for changes with auto-sync every {d}s...\n", .{sync_interval});
|
||||
} else {
|
||||
colors.printInfo("Watching directory for changes...\n", .{});
|
||||
}
|
||||
colors.printInfo("Press Ctrl+C to stop\n", .{});
|
||||
}
|
||||
|
||||
const new_commit_id = try syncAndQueue(allocator, path, job_name, priority, should_queue, config);
|
||||
defer allocator.free(new_commit_id);
|
||||
|
||||
if (!std.mem.eql(u8, last_commit_id, new_commit_id)) {
|
||||
allocator.free(last_commit_id);
|
||||
last_commit_id = try allocator.dupe(u8, new_commit_id);
|
||||
if (!json) {
|
||||
std.debug.print("✓ Synced new version: {s}\n", .{last_commit_id[0..8]});
|
||||
}
|
||||
// Watch loop
|
||||
var last_synced: i64 = 0;
|
||||
while (true) {
|
||||
if (should_sync) {
|
||||
const now = std.time.timestamp();
|
||||
if (now - last_synced >= @as(i64, @intCast(sync_interval))) {
|
||||
// Trigger sync
|
||||
const sync_cmd = @import("sync.zig");
|
||||
sync_cmd.run(allocator, &[_][]const u8{"--json"}) catch |err| {
|
||||
if (!flags.json) {
|
||||
colors.printError("Auto-sync failed: {}\n", .{err});
|
||||
}
|
||||
};
|
||||
last_synced = now;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait before checking again
|
||||
std.Thread.sleep(2_000_000_000); // 2 seconds in nanoseconds
|
||||
std.Thread.sleep(2_000_000_000); // 2 seconds
|
||||
}
|
||||
}
|
||||
|
||||
fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, config: Config) ![]u8 {
|
||||
fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]const u8, priority: u8, should_queue: bool, cfg: config.Config) ![]u8 {
|
||||
// Calculate commit ID
|
||||
const commit_id = try crypto.hashDirectory(allocator, path);
|
||||
|
||||
|
|
@ -112,22 +85,22 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con
|
|||
const remote_path = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"{s}@{s}:{s}/{s}/files/",
|
||||
.{ config.worker_user, config.worker_host, config.worker_base, commit_id },
|
||||
.{ cfg.worker_user, cfg.worker_host, cfg.worker_base, commit_id },
|
||||
);
|
||||
defer allocator.free(remote_path);
|
||||
|
||||
try rsync.sync(allocator, path, remote_path, config.worker_port);
|
||||
try rsync.sync(allocator, path, remote_path, cfg.worker_port);
|
||||
|
||||
if (should_queue) {
|
||||
const actual_job_name = job_name orelse commit_id[0..8];
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, config.api_key);
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
// Connect to WebSocket and queue job
|
||||
const ws_url = try config.getWebSocketUrl(allocator);
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, config.api_key);
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendQueueJob(actual_job_name, commit_id, priority, api_key_hash);
|
||||
|
|
@ -144,11 +117,10 @@ fn syncAndQueue(allocator: std.mem.Allocator, path: []const u8, job_name: ?[]con
|
|||
}
|
||||
|
||||
fn printUsage() void {
|
||||
std.debug.print("Usage: ml watch <path> [options]\n\n", .{});
|
||||
std.debug.print("Usage: ml watch [options]\n\n", .{});
|
||||
std.debug.print("Watch for changes and optionally auto-sync.\n\n", .{});
|
||||
std.debug.print("Options:\n", .{});
|
||||
std.debug.print(" --name <job> Override job name when used with --queue\n", .{});
|
||||
std.debug.print(" --priority <N> Priority to use when queueing (default: 5)\n", .{});
|
||||
std.debug.print(" --queue Queue on every sync\n", .{});
|
||||
std.debug.print(" --json Emit a single JSON line describing watch start\n", .{});
|
||||
std.debug.print(" --sync Auto-sync runs to server every 30s\n", .{});
|
||||
std.debug.print(" --json Output structured JSON\n", .{});
|
||||
std.debug.print(" --help, -h Show this help message\n", .{});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,26 @@
|
|||
const std = @import("std");
|
||||
const security = @import("security.zig");
|
||||
|
||||
pub const ExperimentConfig = struct {
|
||||
name: []const u8,
|
||||
entrypoint: []const u8,
|
||||
};
|
||||
|
||||
/// URI-based configuration for FetchML
|
||||
/// Supports: sqlite:///path/to.db or wss://server.com/ws
|
||||
pub const Config = struct {
|
||||
// Primary storage URI for local mode
|
||||
tracking_uri: []const u8,
|
||||
// Artifacts directory (for local storage)
|
||||
artifact_path: []const u8,
|
||||
// Sync target URI (for pushing local runs to server)
|
||||
sync_uri: []const u8,
|
||||
// Force local mode regardless of server config
|
||||
force_local: bool,
|
||||
// Experiment configuration ([experiment] section)
|
||||
experiment: ?ExperimentConfig,
|
||||
|
||||
// Legacy server config (for runner mode)
|
||||
worker_host: []const u8,
|
||||
worker_user: []const u8,
|
||||
worker_base: []const u8,
|
||||
|
|
@ -20,79 +39,131 @@ pub const Config = struct {
|
|||
default_json: bool,
|
||||
default_priority: u8,
|
||||
|
||||
/// Check if this is local mode (sqlite://) or runner mode (wss://)
|
||||
pub fn isLocalMode(self: Config) bool {
|
||||
return std.mem.startsWith(u8, self.tracking_uri, "sqlite://");
|
||||
}
|
||||
|
||||
/// Get the database path from tracking_uri (removes sqlite:// prefix)
|
||||
pub fn getDBPath(self: Config, allocator: std.mem.Allocator) ![]const u8 {
|
||||
const prefix = "sqlite://";
|
||||
if (!std.mem.startsWith(u8, self.tracking_uri, prefix)) {
|
||||
return error.InvalidTrackingURI;
|
||||
}
|
||||
|
||||
const path = self.tracking_uri[prefix.len..];
|
||||
|
||||
// Handle ~ expansion for home directory
|
||||
if (path.len > 0 and path[0] == '~') {
|
||||
const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;
|
||||
return std.fmt.allocPrint(allocator, "{s}{s}", .{ home, path[1..] });
|
||||
}
|
||||
|
||||
return allocator.dupe(u8, path);
|
||||
}
|
||||
|
||||
pub fn validate(self: Config) !void {
|
||||
// Validate host
|
||||
if (self.worker_host.len == 0) {
|
||||
return error.EmptyHost;
|
||||
}
|
||||
// Only validate server config if not in local mode
|
||||
if (!self.isLocalMode()) {
|
||||
// Validate host
|
||||
if (self.worker_host.len == 0) {
|
||||
return error.EmptyHost;
|
||||
}
|
||||
|
||||
// Validate port range
|
||||
if (self.worker_port == 0 or self.worker_port > 65535) {
|
||||
return error.InvalidPort;
|
||||
}
|
||||
// Validate port range
|
||||
if (self.worker_port == 0 or self.worker_port > 65535) {
|
||||
return error.InvalidPort;
|
||||
}
|
||||
|
||||
// Validate API key presence
|
||||
if (self.api_key.len == 0) {
|
||||
return error.EmptyAPIKey;
|
||||
}
|
||||
// Validate API key presence
|
||||
if (self.api_key.len == 0) {
|
||||
return error.EmptyAPIKey;
|
||||
}
|
||||
|
||||
// Validate base path
|
||||
if (self.worker_base.len == 0) {
|
||||
return error.EmptyBasePath;
|
||||
// Validate base path
|
||||
if (self.worker_base.len == 0) {
|
||||
return error.EmptyBasePath;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load(allocator: std.mem.Allocator) !Config {
|
||||
const home = std.posix.getenv("HOME") orelse return error.NoHomeDir;
|
||||
const config_path = try std.fmt.allocPrint(allocator, "{s}/.ml/config.toml", .{home});
|
||||
defer allocator.free(config_path);
|
||||
/// Load config with priority: CLI > Env > Project > Global > Default
|
||||
pub fn loadWithOverrides(allocator: std.mem.Allocator, cli_tracking_uri: ?[]const u8, cli_artifact_path: ?[]const u8, cli_sync_uri: ?[]const u8) !Config {
|
||||
// Start with defaults
|
||||
var config = try loadDefaults(allocator);
|
||||
|
||||
const file = std.fs.openFileAbsolute(config_path, .{}) catch |err| {
|
||||
if (err == error.FileNotFound) {
|
||||
std.debug.print("Config file not found. Run 'ml init' first.\n", .{});
|
||||
return error.ConfigNotFound;
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
// Load config with environment variable overrides
|
||||
var config = try loadFromFile(allocator, file);
|
||||
|
||||
// Apply environment variable overrides (FETCH_ML_CLI_* to match TUI)
|
||||
if (std.posix.getenv("FETCH_ML_CLI_HOST")) |host| {
|
||||
config.worker_host = try allocator.dupe(u8, host);
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_USER")) |user| {
|
||||
config.worker_user = try allocator.dupe(u8, user);
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_BASE")) |base| {
|
||||
config.worker_base = try allocator.dupe(u8, base);
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_PORT")) |port_str| {
|
||||
config.worker_port = try std.fmt.parseInt(u16, port_str, 10);
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_API_KEY")) |api_key| {
|
||||
config.api_key = try allocator.dupe(u8, api_key);
|
||||
// Priority 4: Apply global config if exists
|
||||
if (try loadGlobalConfig(allocator)) |global| {
|
||||
config.apply(global);
|
||||
config.deinitGlobal(allocator, global);
|
||||
}
|
||||
|
||||
// Try to get API key from keychain if not in config or env
|
||||
if (config.api_key.len == 0) {
|
||||
if (try security.SecureStorage.retrieveApiKey(allocator)) |keychain_key| {
|
||||
config.api_key = keychain_key;
|
||||
}
|
||||
// Priority 3: Apply project config if exists
|
||||
if (try loadProjectConfig(allocator)) |project| {
|
||||
config.apply(project);
|
||||
config.deinitGlobal(allocator, project);
|
||||
}
|
||||
|
||||
// Priority 2: Apply environment variables
|
||||
config.applyEnv(allocator);
|
||||
|
||||
// Priority 1: Apply CLI overrides
|
||||
if (cli_tracking_uri) |uri| {
|
||||
allocator.free(config.tracking_uri);
|
||||
config.tracking_uri = try allocator.dupe(u8, uri);
|
||||
}
|
||||
if (cli_artifact_path) |path| {
|
||||
allocator.free(config.artifact_path);
|
||||
config.artifact_path = try allocator.dupe(u8, path);
|
||||
}
|
||||
if (cli_sync_uri) |uri| {
|
||||
allocator.free(config.sync_uri);
|
||||
config.sync_uri = try allocator.dupe(u8, uri);
|
||||
}
|
||||
|
||||
try config.validate();
|
||||
return config;
|
||||
}
|
||||
|
||||
/// Legacy load function (no overrides)
|
||||
pub fn load(allocator: std.mem.Allocator) !Config {
|
||||
return loadWithOverrides(allocator, null, null, null);
|
||||
}
|
||||
|
||||
/// Load default configuration
|
||||
fn loadDefaults(allocator: std.mem.Allocator) !Config {
|
||||
return Config{
|
||||
.tracking_uri = try allocator.dupe(u8, "sqlite://./fetch_ml.db"),
|
||||
.artifact_path = try allocator.dupe(u8, "./experiments/"),
|
||||
.sync_uri = try allocator.dupe(u8, ""),
|
||||
.force_local = false,
|
||||
.experiment = null,
|
||||
.worker_host = try allocator.dupe(u8, ""),
|
||||
.worker_user = try allocator.dupe(u8, ""),
|
||||
.worker_base = try allocator.dupe(u8, ""),
|
||||
.worker_port = 22,
|
||||
.api_key = try allocator.dupe(u8, ""),
|
||||
.default_cpu = 2,
|
||||
.default_memory = 8,
|
||||
.default_gpu = 0,
|
||||
.default_gpu_memory = null,
|
||||
.default_dry_run = false,
|
||||
.default_validate = false,
|
||||
.default_json = false,
|
||||
.default_priority = 5,
|
||||
};
|
||||
}
|
||||
|
||||
fn loadFromFile(allocator: std.mem.Allocator, file: std.fs.File) !Config {
|
||||
const content = try file.readToEndAlloc(allocator, 1024 * 1024);
|
||||
defer allocator.free(content);
|
||||
|
||||
// Simple TOML parser - parse key=value pairs
|
||||
// Simple TOML parser - parse key=value pairs and [section] headers
|
||||
var config = Config{
|
||||
.tracking_uri = "",
|
||||
.artifact_path = "",
|
||||
.sync_uri = "",
|
||||
.force_local = false,
|
||||
.experiment = null,
|
||||
.worker_host = "",
|
||||
.worker_user = "",
|
||||
.worker_base = "",
|
||||
|
|
@ -108,11 +179,21 @@ pub const Config = struct {
|
|||
.default_priority = 5,
|
||||
};
|
||||
|
||||
var current_section: []const u8 = "root";
|
||||
var experiment_name: ?[]const u8 = null;
|
||||
var experiment_entrypoint: ?[]const u8 = null;
|
||||
|
||||
var lines = std.mem.splitScalar(u8, content, '\n');
|
||||
while (lines.next()) |line| {
|
||||
const trimmed = std.mem.trim(u8, line, " \t\r");
|
||||
if (trimmed.len == 0 or trimmed[0] == '#') continue;
|
||||
|
||||
// Check for section header [section]
|
||||
if (trimmed[0] == '[' and trimmed[trimmed.len - 1] == ']') {
|
||||
current_section = trimmed[1 .. trimmed.len - 1];
|
||||
continue;
|
||||
}
|
||||
|
||||
var parts = std.mem.splitScalar(u8, trimmed, '=');
|
||||
const key = std.mem.trim(u8, parts.next() orelse continue, " \t");
|
||||
const value_raw = std.mem.trim(u8, parts.next() orelse continue, " \t");
|
||||
|
|
@ -123,37 +204,67 @@ pub const Config = struct {
|
|||
else
|
||||
value_raw;
|
||||
|
||||
if (std.mem.eql(u8, key, "worker_host")) {
|
||||
config.worker_host = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "worker_user")) {
|
||||
config.worker_user = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "worker_base")) {
|
||||
config.worker_base = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "worker_port")) {
|
||||
config.worker_port = try std.fmt.parseInt(u16, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "api_key")) {
|
||||
config.api_key = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "default_cpu")) {
|
||||
config.default_cpu = try std.fmt.parseInt(u8, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "default_memory")) {
|
||||
config.default_memory = try std.fmt.parseInt(u8, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "default_gpu")) {
|
||||
config.default_gpu = try std.fmt.parseInt(u8, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "default_gpu_memory")) {
|
||||
if (value.len > 0) {
|
||||
config.default_gpu_memory = try allocator.dupe(u8, value);
|
||||
// Parse based on current section
|
||||
if (std.mem.eql(u8, current_section, "experiment")) {
|
||||
if (std.mem.eql(u8, key, "name")) {
|
||||
experiment_name = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "entrypoint")) {
|
||||
experiment_entrypoint = try allocator.dupe(u8, value);
|
||||
}
|
||||
} else {
|
||||
// Root level keys
|
||||
if (std.mem.eql(u8, key, "tracking_uri")) {
|
||||
config.tracking_uri = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "artifact_path")) {
|
||||
config.artifact_path = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "sync_uri")) {
|
||||
config.sync_uri = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "force_local")) {
|
||||
config.force_local = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "worker_host")) {
|
||||
config.worker_host = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "worker_user")) {
|
||||
config.worker_user = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "worker_base")) {
|
||||
config.worker_base = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "worker_port")) {
|
||||
config.worker_port = try std.fmt.parseInt(u16, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "api_key")) {
|
||||
config.api_key = try allocator.dupe(u8, value);
|
||||
} else if (std.mem.eql(u8, key, "default_cpu")) {
|
||||
config.default_cpu = try std.fmt.parseInt(u8, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "default_memory")) {
|
||||
config.default_memory = try std.fmt.parseInt(u8, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "default_gpu")) {
|
||||
config.default_gpu = try std.fmt.parseInt(u8, value, 10);
|
||||
} else if (std.mem.eql(u8, key, "default_gpu_memory")) {
|
||||
if (value.len > 0) {
|
||||
config.default_gpu_memory = try allocator.dupe(u8, value);
|
||||
}
|
||||
} else if (std.mem.eql(u8, key, "default_dry_run")) {
|
||||
config.default_dry_run = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "default_validate")) {
|
||||
config.default_validate = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "default_json")) {
|
||||
config.default_json = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "default_priority")) {
|
||||
config.default_priority = try std.fmt.parseInt(u8, value, 10);
|
||||
}
|
||||
} else if (std.mem.eql(u8, key, "default_dry_run")) {
|
||||
config.default_dry_run = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "default_validate")) {
|
||||
config.default_validate = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "default_json")) {
|
||||
config.default_json = std.mem.eql(u8, value, "true");
|
||||
} else if (std.mem.eql(u8, key, "default_priority")) {
|
||||
config.default_priority = try std.fmt.parseInt(u8, value, 10);
|
||||
}
|
||||
}
|
||||
|
||||
// Create experiment config if both name and entrypoint are set
|
||||
if (experiment_name != null and experiment_entrypoint != null) {
|
||||
config.experiment = ExperimentConfig{
|
||||
.name = experiment_name.?,
|
||||
.entrypoint = experiment_entrypoint.?,
|
||||
};
|
||||
} else if (experiment_name != null) {
|
||||
allocator.free(experiment_name.?);
|
||||
} else if (experiment_entrypoint != null) {
|
||||
allocator.free(experiment_entrypoint.?);
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -174,27 +285,68 @@ pub const Config = struct {
|
|||
const file = try std.fs.createFileAbsolute(config_path, .{});
|
||||
defer file.close();
|
||||
|
||||
const writer = file.writer();
|
||||
try writer.print("worker_host = \"{s}\"\n", .{self.worker_host});
|
||||
try writer.print("worker_user = \"{s}\"\n", .{self.worker_user});
|
||||
try writer.print("worker_base = \"{s}\"\n", .{self.worker_base});
|
||||
try writer.print("worker_port = {d}\n", .{self.worker_port});
|
||||
try writer.print("api_key = \"{s}\"\n", .{self.api_key});
|
||||
try writer.print("\n# Default resource requests\n", .{});
|
||||
try writer.print("default_cpu = {d}\n", .{self.default_cpu});
|
||||
try writer.print("default_memory = {d}\n", .{self.default_memory});
|
||||
try writer.print("default_gpu = {d}\n", .{self.default_gpu});
|
||||
if (self.default_gpu_memory) |gpu_mem| {
|
||||
try writer.print("default_gpu_memory = \"{s}\"\n", .{gpu_mem});
|
||||
}
|
||||
try writer.print("\n# CLI behavior defaults\n", .{});
|
||||
try writer.print("default_dry_run = {s}\n", .{if (self.default_dry_run) "true" else "false"});
|
||||
try writer.print("default_validate = {s}\n", .{if (self.default_validate) "true" else "false"});
|
||||
try writer.print("default_json = {s}\n", .{if (self.default_json) "true" else "false"});
|
||||
try writer.print("default_priority = {d}\n", .{self.default_priority});
|
||||
// Write config directly using fmt.allocPrint and file.writeAll
|
||||
const content = try std.fmt.allocPrint(allocator,
|
||||
\\# FetchML Configuration
|
||||
\\tracking_uri = "{s}"
|
||||
\\artifact_path = "{s}"
|
||||
\\sync_uri = "{s}"
|
||||
\\force_local = {s}
|
||||
\\{s}
|
||||
\\# Server config (for runner mode)
|
||||
\\worker_host = "{s}"
|
||||
\\worker_user = "{s}"
|
||||
\\worker_base = "{s}"
|
||||
\\worker_port = {d}
|
||||
\\api_key = "{s}"
|
||||
\\
|
||||
\\# Default resource requests
|
||||
\\default_cpu = {d}
|
||||
\\default_memory = {d}
|
||||
\\default_gpu = {d}
|
||||
\\{s}
|
||||
\\# CLI behavior defaults
|
||||
\\default_dry_run = {s}
|
||||
\\default_validate = {s}
|
||||
\\default_json = {s}
|
||||
\\default_priority = {d}
|
||||
\\
|
||||
, .{
|
||||
self.tracking_uri,
|
||||
self.artifact_path,
|
||||
self.sync_uri,
|
||||
if (self.force_local) "true" else "false",
|
||||
if (self.experiment) |exp| try std.fmt.allocPrint(allocator,
|
||||
\\n[experiment]\nname = "{s}"\nentrypoint = "{s}"\n
|
||||
, .{ exp.name, exp.entrypoint }) else "",
|
||||
self.worker_host,
|
||||
self.worker_user,
|
||||
self.worker_base,
|
||||
self.worker_port,
|
||||
self.api_key,
|
||||
self.default_cpu,
|
||||
self.default_memory,
|
||||
self.default_gpu,
|
||||
if (self.default_gpu_memory) |gpu_mem| try std.fmt.allocPrint(allocator,
|
||||
\\default_gpu_memory = "{s}"\n
|
||||
, .{gpu_mem}) else "",
|
||||
if (self.default_dry_run) "true" else "false",
|
||||
if (self.default_validate) "true" else "false",
|
||||
if (self.default_json) "true" else "false",
|
||||
self.default_priority,
|
||||
});
|
||||
defer allocator.free(content);
|
||||
try file.writeAll(content);
|
||||
}
|
||||
|
||||
pub fn deinit(self: *Config, allocator: std.mem.Allocator) void {
|
||||
allocator.free(self.tracking_uri);
|
||||
allocator.free(self.artifact_path);
|
||||
allocator.free(self.sync_uri);
|
||||
if (self.experiment) |*exp| {
|
||||
allocator.free(exp.name);
|
||||
allocator.free(exp.entrypoint);
|
||||
}
|
||||
allocator.free(self.worker_host);
|
||||
allocator.free(self.worker_user);
|
||||
allocator.free(self.worker_base);
|
||||
|
|
@ -204,6 +356,101 @@ pub const Config = struct {
|
|||
}
|
||||
}
|
||||
|
||||
/// Apply settings from another config (for layering)
|
||||
fn apply(self: *Config, other: Config) void {
|
||||
if (other.tracking_uri.len > 0) {
|
||||
self.tracking_uri = other.tracking_uri;
|
||||
}
|
||||
if (other.artifact_path.len > 0) {
|
||||
self.artifact_path = other.artifact_path;
|
||||
}
|
||||
if (other.sync_uri.len > 0) {
|
||||
self.sync_uri = other.sync_uri;
|
||||
}
|
||||
if (other.force_local) {
|
||||
self.force_local = other.force_local;
|
||||
}
|
||||
if (other.experiment) |exp| {
|
||||
if (self.experiment == null) {
|
||||
self.experiment = exp;
|
||||
}
|
||||
}
|
||||
if (other.worker_host.len > 0) {
|
||||
self.worker_host = other.worker_host;
|
||||
}
|
||||
if (other.worker_user.len > 0) {
|
||||
self.worker_user = other.worker_user;
|
||||
}
|
||||
if (other.worker_base.len > 0) {
|
||||
self.worker_base = other.worker_base;
|
||||
}
|
||||
if (other.worker_port != 22) {
|
||||
self.worker_port = other.worker_port;
|
||||
}
|
||||
if (other.api_key.len > 0) {
|
||||
self.api_key = other.api_key;
|
||||
}
|
||||
}
|
||||
|
||||
/// Deinit a config that was loaded temporarily
|
||||
fn deinitGlobal(self: Config, allocator: std.mem.Allocator, other: Config) void {
|
||||
_ = self;
|
||||
allocator.free(other.tracking_uri);
|
||||
allocator.free(other.artifact_path);
|
||||
allocator.free(other.sync_uri);
|
||||
if (other.experiment) |*exp| {
|
||||
allocator.free(exp.name);
|
||||
allocator.free(exp.entrypoint);
|
||||
}
|
||||
allocator.free(other.worker_host);
|
||||
allocator.free(other.worker_user);
|
||||
allocator.free(other.worker_base);
|
||||
allocator.free(other.api_key);
|
||||
if (other.default_gpu_memory) |gpu_mem| {
|
||||
allocator.free(gpu_mem);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply environment variable overrides
|
||||
fn applyEnv(self: *Config, allocator: std.mem.Allocator) void {
|
||||
// FETCHML_* environment variables for URI-based config
|
||||
if (std.posix.getenv("FETCHML_TRACKING_URI")) |uri| {
|
||||
allocator.free(self.tracking_uri);
|
||||
self.tracking_uri = allocator.dupe(u8, uri) catch self.tracking_uri;
|
||||
}
|
||||
if (std.posix.getenv("FETCHML_ARTIFACT_PATH")) |path| {
|
||||
allocator.free(self.artifact_path);
|
||||
self.artifact_path = allocator.dupe(u8, path) catch self.artifact_path;
|
||||
}
|
||||
if (std.posix.getenv("FETCHML_SYNC_URI")) |uri| {
|
||||
allocator.free(self.sync_uri);
|
||||
self.sync_uri = allocator.dupe(u8, uri) catch self.sync_uri;
|
||||
}
|
||||
|
||||
// Legacy FETCH_ML_CLI_* variables
|
||||
if (std.posix.getenv("FETCH_ML_CLI_HOST")) |host| {
|
||||
allocator.free(self.worker_host);
|
||||
self.worker_host = allocator.dupe(u8, host) catch self.worker_host;
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_USER")) |user| {
|
||||
allocator.free(self.worker_user);
|
||||
self.worker_user = allocator.dupe(u8, user) catch self.worker_user;
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_BASE")) |base| {
|
||||
allocator.free(self.worker_base);
|
||||
self.worker_base = allocator.dupe(u8, base) catch self.worker_base;
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_PORT")) |port_str| {
|
||||
if (std.fmt.parseInt(u16, port_str, 10)) |port| {
|
||||
self.worker_port = port;
|
||||
} else |_| {}
|
||||
}
|
||||
if (std.posix.getenv("FETCH_ML_CLI_API_KEY")) |api_key| {
|
||||
allocator.free(self.api_key);
|
||||
self.api_key = allocator.dupe(u8, api_key) catch self.api_key;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get WebSocket URL for connecting to the server
|
||||
pub fn getWebSocketUrl(self: Config, allocator: std.mem.Allocator) ![]u8 {
|
||||
const protocol = if (self.worker_port == 443) "wss" else "ws";
|
||||
|
|
@ -212,3 +459,29 @@ pub const Config = struct {
|
|||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Load global config from ~/.ml/config.toml
|
||||
fn loadGlobalConfig(allocator: std.mem.Allocator) !?Config {
|
||||
const home = std.posix.getenv("HOME") orelse return null;
|
||||
const config_path = try std.fmt.allocPrint(allocator, "{s}/.ml/config.toml", .{home});
|
||||
defer allocator.free(config_path);
|
||||
|
||||
const file = std.fs.openFileAbsolute(config_path, .{ .lock = .none }) catch |err| {
|
||||
if (err == error.FileNotFound) return null;
|
||||
return err;
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
return try Config.loadFromFile(allocator, file);
|
||||
}
|
||||
|
||||
/// Load project config from .fetchml/config.toml in CWD
|
||||
fn loadProjectConfig(allocator: std.mem.Allocator) !?Config {
|
||||
const file = std.fs.openFileAbsolute(".fetchml/config.toml", .{ .lock = .none }) catch |err| {
|
||||
if (err == error.FileNotFound) return null;
|
||||
return err;
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
return try Config.loadFromFile(allocator, file);
|
||||
}
|
||||
|
|
|
|||
4
cli/src/core.zig
Normal file
4
cli/src/core.zig
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub const flags = @import("core/flags.zig");
|
||||
pub const output = @import("core/output.zig");
|
||||
pub const context = @import("core/context.zig");
|
||||
pub const experiment = @import("core/experiment_core.zig");
|
||||
132
cli/src/core/context.zig
Normal file
132
cli/src/core/context.zig
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
const std = @import("std");
|
||||
const config = @import("../config.zig");
|
||||
const output = @import("output.zig");
|
||||
|
||||
/// Execution mode for commands
|
||||
pub const Mode = enum {
|
||||
local,
|
||||
server,
|
||||
};
|
||||
|
||||
/// Execution context passed to all command handlers
/// Provides unified access to allocator, config, and output mode.
pub const Context = struct {
    allocator: std.mem.Allocator,
    mode: Mode, // derived once from cfg.isLocalMode() in init()
    cfg: config.Config,
    json_output: bool, // when true, helpers emit machine-readable JSON

    /// Initialize context from config.
    /// Takes ownership of `cfg`; deinit() releases it through `allocator`.
    pub fn init(allocator: std.mem.Allocator, cfg: config.Config, json_output: bool) Context {
        const mode: Mode = if (cfg.isLocalMode()) .local else .server;
        return .{
            .allocator = allocator,
            .mode = mode,
            .cfg = cfg,
            .json_output = json_output,
        };
    }

    /// Clean up context resources (releases the owned config).
    pub fn deinit(self: *Context) void {
        self.cfg.deinit(self.allocator);
    }

    /// Check if running in local mode
    pub fn isLocal(self: Context) bool {
        return self.mode == .local;
    }

    /// Check if running in server mode
    pub fn isServer(self: Context) bool {
        return self.mode == .server;
    }

    /// Dispatch to appropriate implementation based on mode
    /// local_fn: function to call in local mode
    /// server_fn: function to call in server mode
    /// Both functions must have the same signature: fn (Context, []const []const u8) anyerror!void
    pub fn dispatch(
        self: Context,
        local_fn: *const fn (Context, []const []const u8) anyerror!void,
        server_fn: *const fn (Context, []const []const u8) anyerror!void,
        args: []const []const u8,
    ) !void {
        switch (self.mode) {
            .local => return local_fn(self, args),
            .server => return server_fn(self, args),
        }
    }

    /// Dispatch with result - returns a value.
    /// Same as dispatch() but for handlers that produce a slice
    /// (ownership conventions are the callee's — typically caller-freed).
    pub fn dispatchWithResult(
        self: Context,
        local_fn: *const fn (Context, []const []const u8) anyerror![]const u8,
        server_fn: *const fn (Context, []const []const u8) anyerror![]const u8,
        args: []const []const u8,
    ) ![]const u8 {
        switch (self.mode) {
            .local => return local_fn(self, args),
            .server => return server_fn(self, args),
        }
    }

    /// Output helpers that respect context settings.
    /// JSON mode routes through core/output.zig; text mode logs directly.
    pub fn errorMsg(self: Context, comptime cmd: []const u8, message: []const u8) void {
        if (self.json_output) {
            output.errorMsg(cmd, message);
        } else {
            std.log.err("{s}: {s}", .{ cmd, message });
        }
    }

    /// Like errorMsg but with an extra free-form details string.
    pub fn errorMsgDetailed(self: Context, comptime cmd: []const u8, message: []const u8, details: []const u8) void {
        if (self.json_output) {
            output.errorMsgDetailed(cmd, message, details);
        } else {
            std.log.err("{s}: {s} - {s}", .{ cmd, message, details });
        }
    }

    /// Emit a bare success envelope. Text mode intentionally prints nothing.
    pub fn success(self: Context, comptime cmd: []const u8) void {
        if (self.json_output) {
            output.success(cmd);
        }
    }

    /// Success with a single key/value payload; text mode prints "key: value".
    pub fn successString(self: Context, comptime cmd: []const u8, comptime key: []const u8, value: []const u8) void {
        if (self.json_output) {
            output.successString(cmd, key, value);
        } else {
            std.debug.print("{s}: {s}\n", .{ key, value });
        }
    }

    /// Informational message, suppressed entirely in JSON mode so stdout
    /// stays machine-parsable.
    pub fn info(self: Context, comptime fmt: []const u8, args: anytype) void {
        if (!self.json_output) {
            std.debug.print(fmt ++ "\n", args);
        }
    }

    /// Print usage text; delegates to the global output module (which itself
    /// honors the global JSON/text mode, not this context's flag).
    pub fn printUsage(_: Context, comptime cmd: []const u8, comptime usage: []const u8) void {
        output.usage(cmd, usage);
    }
};
|
||||
|
||||
/// Return the first argument as the subcommand, or fail with
/// error.MissingSubcommand (after logging) when none was given.
pub fn requireSubcommand(args: []const []const u8, comptime cmd_name: []const u8) ![]const u8 {
    if (args.len != 0) return args[0];
    std.log.err("Command '{s}' requires a subcommand", .{cmd_name});
    return error.MissingSubcommand;
}
|
||||
|
||||
/// If args[0] equals `sub`, return the arguments after it; otherwise null.
pub fn matchSubcommand(args: []const []const u8, comptime sub: []const u8) ?[]const []const u8 {
    if (args.len != 0 and std.mem.eql(u8, args[0], sub)) return args[1..];
    return null;
}
|
||||
136
cli/src/core/experiment_core.zig
Normal file
136
cli/src/core/experiment_core.zig
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
//! Experiment core module - shared validation and formatting logic
|
||||
//!
|
||||
//! This module provides common utilities used by both local and server
|
||||
//! experiment operations, reducing code duplication between modes.
|
||||
|
||||
const std = @import("std");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
/// Validate an experiment name: 1-128 chars, each alphanumeric or one of
/// '_', '-', '.'.
pub fn validateExperimentName(name: []const u8) bool {
    if (name.len == 0 or name.len > 128) return false;

    for (name) |ch| {
        const allowed = std.ascii.isAlphanumeric(ch) or
            ch == '_' or ch == '-' or ch == '.';
        if (!allowed) return false;
    }
    return true;
}
|
||||
|
||||
/// Generate a UUID v4 string (xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx) for
/// experiment IDs. Caller owns the returned slice.
/// Draws 16 random bytes in a single crypto.random.bytes call instead of
/// one RNG call per output character, and matches db.generateUUID's layout.
pub fn generateExperimentId(allocator: std.mem.Allocator) ![]const u8 {
    var raw: [16]u8 = undefined;
    std.crypto.random.bytes(&raw);

    // RFC 4122: version nibble = 4, variant bits = 10xx (hex 8/9/a/b).
    raw[6] = (raw[6] & 0x0f) | 0x40;
    raw[8] = (raw[8] & 0x3f) | 0x80;

    const hex_chars = "0123456789abcdef";
    var buf: [36]u8 = undefined;
    var pos: usize = 0;
    for (raw, 0..) |byte, i| {
        // Hyphens after bytes 4, 6, 8, 10 (string positions 8, 13, 18, 23).
        if (i == 4 or i == 6 or i == 8 or i == 10) {
            buf[pos] = '-';
            pos += 1;
        }
        buf[pos] = hex_chars[byte >> 4];
        buf[pos + 1] = hex_chars[byte & 0x0f];
        pos += 2;
    }

    return try allocator.dupe(u8, &buf);
}
|
||||
|
||||
/// Format one experiment as a JSON object string. Caller owns the result.
/// Uses a single std.fmt.allocPrint instead of the managed ArrayList API
/// (`ArrayList.init(allocator)` / `deinit()`), which is inconsistent with
/// the unmanaged ArrayList usage elsewhere in this codebase.
/// NOTE(review): values are interpolated without JSON escaping — callers
/// must guarantee no quotes/backslashes (names are pre-validated).
pub fn formatExperimentJson(
    allocator: std.mem.Allocator,
    id: []const u8,
    name: []const u8,
    lifecycle: []const u8,
    created: []const u8,
) ![]const u8 {
    return std.fmt.allocPrint(
        allocator,
        "{{\"id\":\"{s}\",\"name\":\"{s}\",\"lifecycle\":\"{s}\",\"created\":\"{s}\"}}",
        .{ id, name, lifecycle, created },
    );
}
|
||||
|
||||
/// Format a list of experiments as a JSON array string. Caller owns the result.
/// Rewritten on the unmanaged ArrayList API used by the rest of this codebase;
/// the previous managed `ArrayList.init(allocator)` / `append(x)` calls do not
/// match that convention.
pub fn formatExperimentListJson(
    allocator: std.mem.Allocator,
    experiments: []const Experiment,
) ![]const u8 {
    var buf: std.ArrayList(u8) = .empty;
    defer buf.deinit(allocator);

    try buf.append(allocator, '[');
    for (experiments, 0..) |exp, i| {
        if (i > 0) try buf.append(allocator, ',');
        const item = try std.fmt.allocPrint(
            allocator,
            "{{\"id\":\"{s}\",\"name\":\"{s}\",\"lifecycle\":\"{s}\",\"created\":\"{s}\"}}",
            .{ exp.id, exp.name, exp.lifecycle, exp.created },
        );
        defer allocator.free(item);
        try buf.appendSlice(allocator, item);
    }
    try buf.append(allocator, ']');

    return buf.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Experiment struct for shared use.
/// Plain value record: all fields are borrowed slices — the caller owns the
/// backing memory and controls its lifetime.
pub const Experiment = struct {
    id: []const u8, // UUID string
    name: []const u8,
    lifecycle: []const u8, // lifecycle state string; 'active' is the DB default
    created: []const u8, // creation timestamp as stored in the DB
};
|
||||
|
||||
/// Print one experiment as a pipe-separated row (text mode only; info() is
/// a no-op when the global output mode is JSON).
pub fn printExperimentText(exp: Experiment) void {
    core.output.info(
        " {s} | {s} | {s} | {s}",
        .{ exp.id, exp.name, exp.lifecycle, exp.created },
    );
}
|
||||
|
||||
/// Format one metric as a JSON object string. Caller owns the result.
/// Single allocPrint replaces the managed ArrayList API, which does not match
/// the unmanaged ArrayList convention used elsewhere in this codebase.
/// The value keeps the original fixed 6-decimal formatting.
pub fn formatMetricJson(
    allocator: std.mem.Allocator,
    name: []const u8,
    value: f64,
    step: u32,
) ![]const u8 {
    return std.fmt.allocPrint(
        allocator,
        "{{\"name\":\"{s}\",\"value\":{d:.6},\"step\":{d}}}",
        .{ name, value, step },
    );
}
|
||||
|
||||
/// A metric value is valid when it is a finite float (neither NaN nor ±inf).
pub fn validateMetricValue(value: f64) bool {
    if (std.math.isNan(value)) return false;
    return !std.math.isInf(value);
}
|
||||
|
||||
/// Build a run ID as "<experiment_id>_<timestamp>". Caller owns the result.
/// allocPrint sizes the output exactly; the previous fixed 64-byte stack
/// buffer returned error.NoSpaceLeft for experiment ids longer than ~44 chars.
pub fn formatRunId(allocator: std.mem.Allocator, experiment_id: []const u8, timestamp: i64) ![]const u8 {
    return std.fmt.allocPrint(allocator, "{s}_{d}", .{ experiment_id, timestamp });
}
|
||||
135
cli/src/core/flags.zig
Normal file
135
cli/src/core/flags.zig
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Common flags supported by most commands.
/// parseCommon() fills these in-place; all default to off.
pub const CommonFlags = struct {
    json: bool = false, // --json: machine-readable output
    help: bool = false, // --help / -h
    verbose: bool = false, // --verbose / -v
    dry_run: bool = false, // --dry-run: describe without executing
};
|
||||
|
||||
/// Parse common flags from command arguments, setting the matching fields
/// of `flags`. Returns the remaining positional (non-flag) arguments; a
/// literal "--" ends flag parsing and everything after it is positional.
/// Caller owns the returned list (deinit with the same allocator).
pub fn parseCommon(allocator: std.mem.Allocator, args: []const []const u8, flags: *CommonFlags) !std.ArrayList([]const u8) {
    var positional = try std.ArrayList([]const u8).initCapacity(allocator, args.len);
    errdefer positional.deinit(allocator);

    var flags_done = false;
    for (args) |arg| {
        if (flags_done) {
            try positional.append(allocator, arg);
        } else if (std.mem.eql(u8, arg, "--json")) {
            flags.json = true;
        } else if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) {
            flags.help = true;
        } else if (std.mem.eql(u8, arg, "--verbose") or std.mem.eql(u8, arg, "-v")) {
            flags.verbose = true;
        } else if (std.mem.eql(u8, arg, "--dry-run")) {
            flags.dry_run = true;
        } else if (std.mem.eql(u8, arg, "--")) {
            flags_done = true;
        } else {
            try positional.append(allocator, arg);
        }
    }

    return positional;
}
|
||||
|
||||
/// Parse a key-value flag, supporting both "--key=value" and "--key value".
/// "--key=value" anywhere in args takes precedence over "--key value".
/// Returns a slice into `args` (no allocation), or null when absent or when
/// "--key" is the last argument.
/// Rewritten allocation-free: the previous version built "--key=" with
/// std.heap.page_allocator on every call and silently returned null (i.e.
/// "flag absent") if that allocation failed.
pub fn parseKVFlag(args: []const []const u8, key: []const u8) ?[]const u8 {
    // Pass 1: --key=value form.
    for (args) |arg| {
        if (!std.mem.startsWith(u8, arg, "--")) continue;
        const body = arg[2..];
        if (std.mem.startsWith(u8, body, key) and
            body.len > key.len and body[key.len] == '=')
        {
            return body[key.len + 1 ..];
        }
    }

    // Pass 2: --key value form.
    var i: usize = 0;
    while (i < args.len) : (i += 1) {
        const arg = args[i];
        if (std.mem.startsWith(u8, arg, "--") and std.mem.eql(u8, arg[2..], key)) {
            if (i + 1 < args.len) {
                return args[i + 1];
            }
            return null;
        }
    }

    return null;
}
|
||||
|
||||
/// Return true when "--<flag>" appears in args.
/// Rewritten allocation-free: the previous version allocPrint'ed "--<flag>"
/// from std.heap.page_allocator on every call and returned false (i.e.
/// "flag absent") if that allocation failed.
pub fn parseBoolFlag(args: []const []const u8, flag: []const u8) bool {
    for (args) |arg| {
        if (std.mem.startsWith(u8, arg, "--") and std.mem.eql(u8, arg[2..], flag)) {
            return true;
        }
    }
    return false;
}
|
||||
|
||||
/// Parse an integer-valued flag of type T; falls back to `default` when the
/// flag is absent or its value does not parse as base-10.
pub fn parseNumFlag(comptime T: type, args: []const []const u8, flag: []const u8, default: T) T {
    const raw = parseKVFlag(args, flag) orelse return default;
    return std.fmt.parseInt(T, raw, 10) catch default;
}
|
||||
|
||||
/// True when any argument exactly matches any of the given flag spellings.
pub fn hasAnyFlag(args: []const []const u8, flags: []const []const u8) bool {
    for (flags) |flag| {
        for (args) |arg| {
            if (std.mem.eql(u8, arg, flag)) return true;
        }
    }
    return false;
}
|
||||
|
||||
/// Peek the first argument, or null when there is none.
pub fn shift(args: []const []const u8) ?[]const u8 {
    return if (args.len != 0) args[0] else null;
}
|
||||
|
||||
/// Return the arguments after the first (empty when there are one or none).
/// BUGFIX: the original returned `&[]const u8{}`, whose element type is u8 —
/// the wrong type for `[]const []const u8`. Slicing the input instead yields
/// a correctly-typed empty slice for the short cases.
pub fn rest(args: []const []const u8) []const []const u8 {
    if (args.len == 0) return args;
    return args[1..];
}
|
||||
|
||||
/// Return the first argument as the subcommand, or fail with
/// error.MissingSubcommand (after logging) when none was given.
/// NOTE: duplicated in core/context.zig — keep the two in sync.
pub fn requireSubcommand(args: []const []const u8, comptime cmd_name: []const u8) ![]const u8 {
    if (args.len != 0) return args[0];
    std.log.err("Command '{s}' requires a subcommand", .{cmd_name});
    return error.MissingSubcommand;
}
|
||||
|
||||
/// If args[0] equals `sub`, return the arguments after it; otherwise null.
/// NOTE: duplicated in core/context.zig — keep the two in sync.
pub fn matchSubcommand(args: []const []const u8, comptime sub: []const u8) ?[]const []const u8 {
    if (args.len != 0 and std.mem.eql(u8, args[0], sub)) return args[1..];
    return null;
}
||||
129
cli/src/core/output.zig
Normal file
129
cli/src/core/output.zig
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
|
||||
/// Output mode for commands.
pub const OutputMode = enum {
    text, // human-readable output (the default)
    json, // one-line machine-readable JSON envelopes
};
|
||||
|
||||
/// Global output mode - set by main based on --json flag.
/// NOTE(review): package-level mutable state; acceptable for a
/// single-threaded CLI, but every helper below reads it.
pub var global_mode: OutputMode = .text;

/// Initialize output mode from command flags.
pub fn init(mode: OutputMode) void {
    global_mode = mode;
}
|
||||
|
||||
/// Print an error in the active output format: a JSON failure envelope in
/// JSON mode, or a colored message in text mode.
pub fn errorMsg(comptime command: []const u8, message: []const u8) void {
    if (global_mode == .json) {
        std.debug.print(
            "{{\"success\":false,\"command\":\"{s}\",\"error\":\"{s}\"}}\n",
            .{ command, message },
        );
    } else {
        colors.printError("{s}\n", .{message});
    }
}
|
||||
|
||||
/// Print an error plus a details string in the active output format.
pub fn errorMsgDetailed(comptime command: []const u8, message: []const u8, details: []const u8) void {
    if (global_mode == .json) {
        std.debug.print(
            "{{\"success\":false,\"command\":\"{s}\",\"error\":\"{s}\",\"details\":\"{s}\"}}\n",
            .{ command, message, details },
        );
        return;
    }
    colors.printError("{s}\n", .{message});
    std.debug.print("Details: {s}\n", .{details});
}
|
||||
|
||||
/// Print a bare success envelope. Text mode deliberately prints nothing —
/// simple success is silent for humans.
pub fn success(comptime command: []const u8) void {
    if (global_mode == .json) {
        std.debug.print("{{\"success\":true,\"command\":\"{s}\"}}\n", .{command});
    }
}
|
||||
|
||||
/// Print success with a single string datum: a data object in JSON mode,
/// just the raw value in text mode.
pub fn successString(comptime command: []const u8, comptime data_key: []const u8, value: []const u8) void {
    if (global_mode == .json) {
        std.debug.print(
            "{{\"success\":true,\"command\":\"{s}\",\"data\":{{\"{s}\":\"{s}\"}}}}\n",
            .{ command, data_key, value },
        );
    } else {
        std.debug.print("{s}\n", .{value});
    }
}
|
||||
|
||||
/// Print success with formatted string data.
/// In JSON mode the formatted payload is embedded verbatim as the "data"
/// value, so `fmt_str` must itself produce valid JSON.
pub fn successFmt(comptime command: []const u8, comptime fmt_str: []const u8, args: anytype) void {
    switch (global_mode) {
        .json => {
            // Use stack buffer to avoid allocation
            var buf: [4096]u8 = undefined;
            const msg = std.fmt.bufPrint(&buf, fmt_str, args) catch {
                // Payload over 4 KiB: degrade to "data":null rather than fail.
                std.debug.print("{{\"success\":true,\"command\":\"{s}\",\"data\":null}}\n", .{command});
                return;
            };
            std.debug.print("{{\"success\":true,\"command\":\"{s}\",\"data\":{s}}}\n", .{ command, msg });
        },
        .text => std.debug.print(fmt_str ++ "\n", args),
    }
}
|
||||
|
||||
/// Print an informational message; suppressed in JSON mode so stdout stays
/// machine-parsable.
pub fn info(comptime fmt_str: []const u8, args: anytype) void {
    switch (global_mode) {
        .text => std.debug.print(fmt_str ++ "\n", args),
        .json => {},
    }
}
|
||||
|
||||
/// Print usage information: an "Invalid arguments" failure envelope in JSON
/// mode, a plain "Usage: ..." line in text mode.
pub fn usage(comptime cmd: []const u8, comptime usage_str: []const u8) void {
    if (global_mode == .json) {
        std.debug.print(
            "{{\"success\":false,\"command\":\"{s}\",\"error\":\"Invalid arguments\",\"usage\":\"{s}\"}}\n",
            .{ cmd, usage_str },
        );
        return;
    }
    std.debug.print("Usage: {s}\n", .{usage_str});
}
|
||||
|
||||
/// Report an unrecognized (sub)command in the active output format.
pub fn unknownCommand(comptime command: []const u8, unknown: []const u8) void {
    if (global_mode == .json) {
        std.debug.print(
            "{{\"success\":false,\"command\":\"{s}\",\"error\":\"Unknown command: {s}\"}}\n",
            .{ command, unknown },
        );
    } else {
        colors.printError("Unknown command: {s}\n", .{unknown});
    }
}
|
||||
|
||||
/// Print a tab-separated table header row (text mode only).
pub fn tableHeader(comptime cols: []const []const u8) void {
    if (global_mode == .json) return;

    var first = true;
    for (cols) |col| {
        if (!first) std.debug.print("\t", .{});
        std.debug.print("{s}", .{col});
        first = false;
    }
    std.debug.print("\n", .{});
}
|
||||
|
||||
/// Print a tab-separated table data row (text mode only).
pub fn tableRow(values: []const []const u8) void {
    if (global_mode == .json) return;

    var first = true;
    for (values) |val| {
        if (!first) std.debug.print("\t", .{});
        std.debug.print("{s}", .{val});
        first = false;
    }
    std.debug.print("\n", .{});
}
|
||||
264
cli/src/db.zig
Normal file
264
cli/src/db.zig
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
const std = @import("std");

// SQLite C bindings
pub const c = @cImport({
    @cInclude("sqlite3.h");
});

// Public type alias for prepared statement
pub const Stmt = ?*c.sqlite3_stmt;

// SQLITE_TRANSIENT constant - use C wrapper to avoid Zig 0.15 C translation issue
// (the macro casts -1 to a destructor function pointer, which translate-c
// cannot express, so a tiny C shim returns the value instead).
extern fn fetchml_sqlite_transient() c.sqlite3_destructor_type;
fn sqliteTransient() c.sqlite3_destructor_type {
    return fetchml_sqlite_transient();
}
|
||||
|
||||
// Schema for ML tracking tables.
// Executed once per connection in DB.init; every statement is idempotent
// (CREATE TABLE IF NOT EXISTS), so reapplying is safe.
// Tables: experiments -> runs -> (metrics, params, tags), linked by id.
const SCHEMA =
    \\ CREATE TABLE IF NOT EXISTS ml_experiments (
    \\   experiment_id TEXT PRIMARY KEY,
    \\   name TEXT NOT NULL,
    \\   artifact_path TEXT,
    \\   lifecycle TEXT DEFAULT 'active',
    \\   created_at DATETIME DEFAULT CURRENT_TIMESTAMP
    \\ );
    \\ CREATE TABLE IF NOT EXISTS ml_runs (
    \\   run_id TEXT PRIMARY KEY,
    \\   experiment_id TEXT REFERENCES ml_experiments(experiment_id),
    \\   name TEXT,
    \\   status TEXT, -- RUNNING, FINISHED, FAILED, CANCELLED
    \\   start_time DATETIME,
    \\   end_time DATETIME,
    \\   artifact_uri TEXT,
    \\   pid INTEGER DEFAULT NULL,
    \\   synced INTEGER DEFAULT 0
    \\ );
    \\ CREATE TABLE IF NOT EXISTS ml_metrics (
    \\   run_id TEXT REFERENCES ml_runs(run_id),
    \\   key TEXT,
    \\   value REAL,
    \\   step INTEGER DEFAULT 0,
    \\   timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
    \\ );
    \\ CREATE TABLE IF NOT EXISTS ml_params (
    \\   run_id TEXT REFERENCES ml_runs(run_id),
    \\   key TEXT,
    \\   value TEXT
    \\ );
    \\ CREATE TABLE IF NOT EXISTS ml_tags (
    \\   run_id TEXT REFERENCES ml_runs(run_id),
    \\   key TEXT,
    \\   value TEXT
    \\ );
;
|
||||
|
||||
/// Database connection handle wrapping one SQLite connection.
pub const DB = struct {
    handle: ?*c.sqlite3,
    path: []const u8, // owned copy of db_path, released in close()
    allocator: std.mem.Allocator, // allocator that owns `path`

    /// Initialize database with WAL mode and schema.
    /// Opens (creating if needed) the file at `db_path`, switches the journal
    /// to WAL so concurrent CLI writers and TUI readers can coexist, and
    /// applies the idempotent SCHEMA.
    pub fn init(allocator: std.mem.Allocator, db_path: []const u8) !DB {
        var db: ?*c.sqlite3 = null;

        // BUGFIX: sqlite3_open requires a NUL-terminated C string; `db_path`
        // is a Zig slice with no terminator guarantee, so pass a terminated copy.
        const path_z = try allocator.dupeZ(u8, db_path);
        defer allocator.free(path_z);

        const rc = c.sqlite3_open(path_z.ptr, &db);
        if (rc != c.SQLITE_OK) {
            std.log.err("Failed to open database: {s}", .{c.sqlite3_errmsg(db)});
            // SQLite may allocate a handle even on failure; release it.
            _ = c.sqlite3_close(db);
            return error.DBOpenFailed;
        }
        // Release the handle on any error below (schema failure, OOM).
        errdefer _ = c.sqlite3_close(db);

        // Enable WAL mode - required for concurrent CLI writes and TUI reads.
        // Pragma failures are deliberately non-fatal.
        var errmsg: [*c]u8 = null;
        _ = c.sqlite3_exec(db, "PRAGMA journal_mode=WAL;", null, null, &errmsg);
        if (errmsg != null) c.sqlite3_free(errmsg);

        // Set synchronous=NORMAL for performance under WAL.
        _ = c.sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", null, null, &errmsg);
        if (errmsg != null) c.sqlite3_free(errmsg);

        // Apply schema; unlike the pragmas, this failure is fatal.
        _ = c.sqlite3_exec(db, SCHEMA, null, null, &errmsg);
        if (errmsg != null) {
            std.log.err("Schema creation failed: {s}", .{errmsg});
            c.sqlite3_free(errmsg);
            return error.SchemaFailed;
        }

        const path_copy = try allocator.dupe(u8, db_path);

        return DB{
            .handle = db,
            .path = path_copy,
            .allocator = allocator,
        };
    }

    /// Close database connection and release the owned path copy.
    /// Safe to call more than once. (Previously `path` was never freed.)
    pub fn close(self: *DB) void {
        if (self.handle) |db| {
            _ = c.sqlite3_close(db);
            self.handle = null;
        }
        if (self.path.len > 0) {
            self.allocator.free(self.path);
            self.path = "";
        }
    }

    /// Checkpoint WAL on clean shutdown so the -wal file is truncated.
    /// Errors are ignored: checkpointing is a best-effort courtesy.
    pub fn checkpointOnExit(self: *DB) void {
        if (self.handle) |db| {
            var errmsg: [*c]u8 = null;
            _ = c.sqlite3_exec(db, "PRAGMA wal_checkpoint(TRUNCATE);", null, null, &errmsg);
            if (errmsg != null) {
                c.sqlite3_free(errmsg);
            }
        }
    }

    /// Execute a simple SQL statement with no result rows.
    /// NOTE(review): sqlite3_exec expects NUL-terminated SQL; current callers
    /// pass string literals (which Zig NUL-terminates) — confirm before
    /// passing runtime-built SQL, or prefer prepare()/step().
    pub fn exec(self: DB, sql: []const u8) !void {
        if (self.handle == null) return error.DBNotOpen;

        var errmsg: [*c]u8 = null;
        const rc = c.sqlite3_exec(self.handle, sql.ptr, null, null, &errmsg);

        if (rc != c.SQLITE_OK) {
            if (errmsg) |e| {
                std.log.err("SQL error: {s}", .{e});
                c.sqlite3_free(errmsg);
            }
            return error.SQLExecFailed;
        }
    }

    /// Prepare a statement. The SQL is passed with an explicit length, so it
    /// need not be NUL-terminated. Finalize with DB.finalize.
    pub fn prepare(self: DB, sql: []const u8) !?*c.sqlite3_stmt {
        if (self.handle == null) return error.DBNotOpen;

        var stmt: ?*c.sqlite3_stmt = null;
        const rc = c.sqlite3_prepare_v2(self.handle, sql.ptr, @intCast(sql.len), &stmt, null);

        if (rc != c.SQLITE_OK) {
            std.log.err("Prepare failed: {s}", .{c.sqlite3_errmsg(self.handle)});
            return error.PrepareFailed;
        }

        return stmt;
    }

    /// Finalize a prepared statement (null is a no-op).
    pub fn finalize(stmt: ?*c.sqlite3_stmt) void {
        if (stmt) |s| {
            _ = c.sqlite3_finalize(s);
        }
    }

    /// Bind text parameter to statement (1-based index). SQLITE_TRANSIENT
    /// makes SQLite copy the bytes, so `value` need not outlive the bind.
    pub fn bindText(stmt: ?*c.sqlite3_stmt, idx: i32, value: []const u8) !void {
        if (stmt == null) return error.InvalidStatement;
        const rc = c.sqlite3_bind_text(stmt, idx, value.ptr, @intCast(value.len), sqliteTransient());
        if (rc != c.SQLITE_OK) return error.BindFailed;
    }

    /// Bind int64 parameter to statement (1-based index).
    pub fn bindInt64(stmt: ?*c.sqlite3_stmt, idx: i32, value: i64) !void {
        if (stmt == null) return error.InvalidStatement;
        const rc = c.sqlite3_bind_int64(stmt, idx, value);
        if (rc != c.SQLITE_OK) return error.BindFailed;
    }

    /// Bind double parameter to statement (1-based index).
    pub fn bindDouble(stmt: ?*c.sqlite3_stmt, idx: i32, value: f64) !void {
        if (stmt == null) return error.InvalidStatement;
        const rc = c.sqlite3_bind_double(stmt, idx, value);
        if (rc != c.SQLITE_OK) return error.BindFailed;
    }

    /// Step statement (execute). Returns true while a row is available,
    /// false when finished (SQLITE_DONE).
    /// BUGFIX: any other result code (SQLITE_BUSY, constraint violation, ...)
    /// now surfaces as an error instead of being silently reported as "done".
    pub fn step(stmt: ?*c.sqlite3_stmt) !bool {
        if (stmt == null) return error.InvalidStatement;
        const rc = c.sqlite3_step(stmt);
        return switch (rc) {
            c.SQLITE_ROW => true,
            c.SQLITE_DONE => false,
            else => error.StepFailed,
        };
    }

    /// Reset statement and clear its bindings for reuse.
    pub fn reset(stmt: ?*c.sqlite3_stmt) !void {
        if (stmt == null) return error.InvalidStatement;
        _ = c.sqlite3_reset(stmt);
        _ = c.sqlite3_clear_bindings(stmt);
    }

    /// Get column text (0-based index). The returned slice points into
    /// SQLite-owned memory and is only valid until the next step()/reset()/
    /// finalize() on this statement — dupe it to keep it.
    pub fn columnText(stmt: ?*c.sqlite3_stmt, idx: i32) []const u8 {
        if (stmt == null) return "";
        const ptr = c.sqlite3_column_text(stmt, idx);
        const len = c.sqlite3_column_bytes(stmt, idx);
        if (ptr == null or len == 0) return "";
        return ptr[0..@intCast(len)];
    }

    /// Get column int64 (0-based index); 0 for a null statement.
    pub fn columnInt64(stmt: ?*c.sqlite3_stmt, idx: i32) i64 {
        if (stmt == null) return 0;
        return c.sqlite3_column_int64(stmt, idx);
    }

    /// Get column double (0-based index); 0.0 for a null statement.
    pub fn columnDouble(stmt: ?*c.sqlite3_stmt, idx: i32) f64 {
        if (stmt == null) return 0.0;
        return c.sqlite3_column_double(stmt, idx);
    }
};
|
||||
|
||||
/// Generate a UUID v4 string from std.crypto.random.
/// Caller owns the returned 36-byte slice.
pub fn generateUUID(allocator: std.mem.Allocator) ![]const u8 {
    var raw: [16]u8 = undefined;
    std.crypto.random.bytes(&raw);

    // RFC 4122: version nibble = 4, variant bits = 10xx.
    raw[6] = (raw[6] & 0x0f) | 0x40;
    raw[8] = (raw[8] & 0x3f) | 0x80;

    const digits = "0123456789abcdef";
    var out: [36]u8 = undefined;
    var pos: usize = 0;
    for (raw, 0..) |byte, i| {
        // Hyphens after bytes 4, 6, 8 and 10.
        switch (i) {
            4, 6, 8, 10 => {
                out[pos] = '-';
                pos += 1;
            },
            else => {},
        }
        out[pos] = digits[byte >> 4];
        out[pos + 1] = digits[byte & 0x0f];
        pos += 2;
    }

    return try allocator.dupe(u8, &out);
}
|
||||
|
||||
/// Current UTC wall-clock time formatted as "YYYY-MM-DD HH:MM:SS".
/// Caller owns the returned slice.
pub fn currentTimestamp(allocator: std.mem.Allocator) ![]const u8 {
    const now = std.time.timestamp();
    const epoch_seconds = std.time.epoch.EpochSeconds{ .secs = @intCast(now) };

    // Break the epoch second count down into calendar components.
    const epoch_day = epoch_seconds.getEpochDay();
    const year_day = epoch_day.calculateYearDay();
    const month_day = year_day.calculateMonthDay();
    const day_seconds = epoch_seconds.getDaySeconds();

    var buf: [20]u8 = undefined;
    const stamp = try std.fmt.bufPrint(&buf, "{d:0>4}-{d:0>2}-{d:0>2} {d:0>2}:{d:0>2}:{d:0>2}", .{
        year_day.year,
        month_day.month.numeric(),
        month_day.day_index + 1, // day_index is 0-based
        day_seconds.getHoursIntoDay(),
        day_seconds.getMinutesIntoHour(),
        day_seconds.getSecondsIntoMinute(),
    });

    return try allocator.dupe(u8, stamp);
}
|
||||
22
cli/src/local.zig
Normal file
22
cli/src/local.zig
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
//! Local mode operations module
//!
//! Provides implementations for CLI commands when running in local mode (SQLite).
//! These functions are called by the command routers in `src/commands/` when
//! `Context.isLocal()` returns true.
//!
//! ## Usage
//!
//! ```zig
//! const local = @import("../local.zig");
//!
//! if (ctx.isLocal()) {
//!     return try local.experiment.create(ctx.allocator, name, artifact_path, json);
//! }
//! ```
//!
//! ## Module Structure
//!
//! - `experiment_ops.zig` - Experiment CRUD operations for SQLite
//! - Future: `run_ops.zig`, `metrics_ops.zig`, etc.

pub const experiment = @import("local/experiment_ops.zig");
|
||||
167
cli/src/local/experiment_ops.zig
Normal file
167
cli/src/local/experiment_ops.zig
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
const std = @import("std");
|
||||
const db = @import("../db.zig");
|
||||
const config = @import("../config.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
/// Row shape returned by experiment queries. In list(), every field is
/// heap-duplicated from SQLite column text and freed by the caller.
pub const Experiment = struct {
    id: []const u8, // UUID string (ml_experiments.experiment_id)
    name: []const u8,
    lifecycle: []const u8, // 'active' is the DB default
    created: []const u8, // ml_experiments.created_at as text
};
|
||||
|
||||
/// Create a new experiment in local mode.
/// Inserts one row into ml_experiments with a freshly generated UUID, then
/// checkpoints the WAL. Errors with NotLocalMode when the configured
/// tracking backend is not sqlite://.
pub fn create(allocator: std.mem.Allocator, name: []const u8, artifact_path: ?[]const u8, json: bool) !void {
    // Load config
    const cfg = try config.Config.load(allocator);
    defer {
        // deinit needs a mutable receiver; cfg itself is const
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    if (!cfg.isLocalMode()) {
        if (json) {
            core.output.errorMsg("experiment.create", "create only works in local mode (sqlite://)");
        } else {
            std.log.err("Error: experiment create only works in local mode (sqlite://)", .{});
        }
        return error.NotLocalMode;
    }

    // Get DB path
    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);

    // Initialize DB (opens SQLite, enables WAL, applies schema)
    var database = try db.DB.init(allocator, db_path);
    defer database.close();

    // Generate experiment ID
    const exp_id = try db.generateUUID(allocator);
    defer allocator.free(exp_id);

    // Insert experiment; a null artifact_path is stored as the empty string
    const sql = "INSERT INTO ml_experiments (experiment_id, name, artifact_path) VALUES (?, ?, ?);";
    const stmt = try database.prepare(sql);
    defer db.DB.finalize(stmt);

    try db.DB.bindText(stmt, 1, exp_id);
    try db.DB.bindText(stmt, 2, name);
    try db.DB.bindText(stmt, 3, artifact_path orelse "");

    // INSERT yields no rows, so the returned bool is ignored
    _ = try db.DB.step(stmt);
    // Truncate the WAL so readers see a compact database file
    database.checkpointOnExit();

    if (json) {
        std.debug.print("{{\"success\":true,\"command\":\"experiment.create\",\"data\":{{\"experiment_id\":\"{s}\",\"name\":\"{s}\"}}}}\n", .{ exp_id, name });
    } else {
        std.debug.print("Created experiment: {s} (ID: {s})\n", .{ name, exp_id });
    }
}
|
||||
|
||||
/// Log a metric for a run in local mode.
/// Appends one (run_id, key, value, step) row to ml_metrics.
/// NOTE(review): unlike create(), this does not reject non-local mode and
/// does not checkpoint the WAL — confirm whether that asymmetry is intended.
pub fn logMetric(
    allocator: std.mem.Allocator,
    run_id: []const u8,
    name: []const u8,
    value: f64,
    step: i64,
    json: bool,
) !void {
    // Load config
    const cfg = try config.Config.load(allocator);
    defer {
        // deinit needs a mutable receiver; cfg itself is const
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    // Get DB path
    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);

    // Initialize DB
    var database = try db.DB.init(allocator, db_path);
    defer database.close();

    // Insert metric; run_id existence is not validated against ml_runs here
    const sql = "INSERT INTO ml_metrics (run_id, key, value, step) VALUES (?, ?, ?, ?);";
    const stmt = try database.prepare(sql);
    defer db.DB.finalize(stmt);

    try db.DB.bindText(stmt, 1, run_id);
    try db.DB.bindText(stmt, 2, name);
    try db.DB.bindDouble(stmt, 3, value);
    try db.DB.bindInt64(stmt, 4, step);

    // INSERT yields no rows, so the returned bool is ignored
    _ = try db.DB.step(stmt);

    if (json) {
        std.debug.print("{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"run_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}}}}}}\n", .{ run_id, name, value, step });
    } else {
        std.debug.print("Logged metric: {s} = {d:.4} (step {d})\n", .{ name, value, step });
    }
}
|
||||
|
||||
/// List all experiments in local mode.
/// Reads every ml_experiments row (newest first), materializes owned copies,
/// then prints either a JSON envelope or a text table.
pub fn list(allocator: std.mem.Allocator, json: bool) !void {
    // Load config
    const cfg = try config.Config.load(allocator);
    defer {
        // deinit needs a mutable receiver; cfg itself is const
        var mut_cfg = cfg;
        mut_cfg.deinit(allocator);
    }

    // Get DB path
    const db_path = try cfg.getDBPath(allocator);
    defer allocator.free(db_path);

    // Initialize DB
    var database = try db.DB.init(allocator, db_path);
    defer database.close();

    // Query experiments
    const sql = "SELECT experiment_id, name, lifecycle, created_at FROM ml_experiments ORDER BY created_at DESC;";
    const stmt = try database.prepare(sql);
    defer db.DB.finalize(stmt);

    var experiments = std.ArrayList(Experiment).initCapacity(allocator, 10) catch |err| {
        return err;
    };
    // Frees both the duplicated field strings and the list itself.
    defer {
        for (experiments.items) |exp| {
            allocator.free(exp.id);
            allocator.free(exp.name);
            allocator.free(exp.lifecycle);
            allocator.free(exp.created);
        }
        experiments.deinit(allocator);
    }

    // Column text is only valid until the next step(), so dupe each field.
    // NOTE(review): if a dupe/append fails mid-row, strings duped earlier in
    // that iteration leak; an errdefer per field would tighten this.
    while (try db.DB.step(stmt)) {
        const id = try allocator.dupe(u8, db.DB.columnText(stmt, 0));
        const name = try allocator.dupe(u8, db.DB.columnText(stmt, 1));
        const lifecycle = try allocator.dupe(u8, db.DB.columnText(stmt, 2));
        const created = try allocator.dupe(u8, db.DB.columnText(stmt, 3));
        try experiments.append(allocator, .{ .id = id, .name = name, .lifecycle = lifecycle, .created = created });
    }

    if (json) {
        std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[", .{});
        for (experiments.items, 0..) |exp, idx| {
            if (idx > 0) std.debug.print(",", .{});
            std.debug.print("{{\"experiment_id\":\"{s}\",\"name\":\"{s}\",\"lifecycle\":\"{s}\",\"created_at\":\"{s}\"}}", .{ exp.id, exp.name, exp.lifecycle, exp.created });
        }
        std.debug.print("],\"total\":{d}}}}}\n", .{experiments.items.len});
    } else {
        if (experiments.items.len == 0) {
            std.debug.print("No experiments found. Create one with: ml experiment create --name <name>\n", .{});
        } else {
            std.debug.print("\nExperiments:\n", .{});
            std.debug.print("{s:-<60}\n", .{""});
            for (experiments.items) |exp| {
                std.debug.print("{s} | {s} | {s} | {s}\n", .{ exp.id, exp.name, exp.lifecycle, exp.created });
            }
            std.debug.print("\nTotal: {d} experiments\n", .{experiments.items.len});
        }
    }
}
|
||||
101
cli/src/main.zig
101
cli/src/main.zig
|
|
@ -12,15 +12,14 @@ pub fn main() !void {
|
|||
// Initialize colors based on environment
|
||||
colors.initColors();
|
||||
|
||||
// Use ArenaAllocator for thread-safe memory management
|
||||
var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
|
||||
defer arena.deinit();
|
||||
const allocator = arena.allocator();
|
||||
// Use c_allocator for better performance on Linux
|
||||
const allocator = std.heap.c_allocator;
|
||||
|
||||
const args = std.process.argsAlloc(allocator) catch |err| {
|
||||
std.debug.print("Failed to allocate args: {}\n", .{err});
|
||||
return;
|
||||
};
|
||||
defer std.process.argsFree(allocator, args);
|
||||
|
||||
if (args.len < 2) {
|
||||
printUsage();
|
||||
|
|
@ -33,48 +32,56 @@ pub fn main() !void {
|
|||
switch (command[0]) {
|
||||
'j' => if (std.mem.eql(u8, command, "jupyter")) {
|
||||
try @import("commands/jupyter.zig").run(allocator, args[2..]);
|
||||
},
|
||||
} else handleUnknownCommand(command),
|
||||
'i' => if (std.mem.eql(u8, command, "init")) {
|
||||
colors.printInfo("Setup configuration interactively\n", .{});
|
||||
try @import("commands/init.zig").run(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "info")) {
|
||||
try @import("commands/info.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'a' => if (std.mem.eql(u8, command, "annotate")) {
|
||||
try @import("commands/annotate.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'n' => if (std.mem.eql(u8, command, "narrative")) {
|
||||
try @import("commands/narrative.zig").run(allocator, args[2..]);
|
||||
},
|
||||
try @import("commands/annotate.zig").execute(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'e' => if (std.mem.eql(u8, command, "experiment")) {
|
||||
try @import("commands/experiment.zig").execute(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "export")) {
|
||||
try @import("commands/export_cmd.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
's' => if (std.mem.eql(u8, command, "sync")) {
|
||||
if (args.len < 3) {
|
||||
colors.printError("Usage: ml sync <path>\n", .{});
|
||||
return error.InvalidArgs;
|
||||
}
|
||||
colors.printInfo("Sync project to server: {s}\n", .{args[2]});
|
||||
try @import("commands/sync.zig").run(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "status")) {
|
||||
try @import("commands/status.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'r' => if (std.mem.eql(u8, command, "requeue")) {
|
||||
try @import("commands/requeue.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'r' => if (std.mem.eql(u8, command, "run")) {
|
||||
try @import("commands/run.zig").execute(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'q' => if (std.mem.eql(u8, command, "queue")) {
|
||||
try @import("commands/queue.zig").run(allocator, args[2..]);
|
||||
},
|
||||
} else handleUnknownCommand(command),
|
||||
'd' => if (std.mem.eql(u8, command, "dataset")) {
|
||||
try @import("commands/dataset.zig").run(allocator, args[2..]);
|
||||
},
|
||||
'e' => if (std.mem.eql(u8, command, "experiment")) {
|
||||
try @import("commands/experiment.zig").execute(allocator, args[2..]);
|
||||
},
|
||||
} else handleUnknownCommand(command),
|
||||
'x' => if (std.mem.eql(u8, command, "export")) {
|
||||
try @import("commands/export_cmd.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'c' => if (std.mem.eql(u8, command, "cancel")) {
|
||||
try @import("commands/cancel.zig").run(allocator, args[2..]);
|
||||
},
|
||||
} else if (std.mem.eql(u8, command, "compare")) {
|
||||
try @import("commands/compare.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'f' => if (std.mem.eql(u8, command, "find")) {
|
||||
try @import("commands/find.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'v' => if (std.mem.eql(u8, command, "validate")) {
|
||||
try @import("commands/validate.zig").run(allocator, args[2..]);
|
||||
},
|
||||
} else handleUnknownCommand(command),
|
||||
'l' => if (std.mem.eql(u8, command, "logs")) {
|
||||
try @import("commands/logs.zig").run(allocator, args[2..]);
|
||||
},
|
||||
try @import("commands/log.zig").run(allocator, args[2..]);
|
||||
} else if (std.mem.eql(u8, command, "log")) {
|
||||
try @import("commands/log.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
'w' => if (std.mem.eql(u8, command, "watch")) {
|
||||
try @import("commands/watch.zig").run(allocator, args[2..]);
|
||||
} else handleUnknownCommand(command),
|
||||
else => {
|
||||
colors.printError("Unknown command: {s}\n", .{args[1]});
|
||||
printUsage();
|
||||
|
|
@ -88,30 +95,32 @@ fn printUsage() void {
|
|||
colors.printInfo("ML Experiment Manager\n\n", .{});
|
||||
std.debug.print("Usage: ml <command> [options]\n\n", .{});
|
||||
std.debug.print("Commands:\n", .{});
|
||||
std.debug.print(" jupyter Jupyter workspace management\n", .{});
|
||||
std.debug.print(" init Setup configuration interactively\n", .{});
|
||||
std.debug.print(" annotate <path|id> Add an annotation to run_manifest.json (--note \"...\")\n", .{});
|
||||
std.debug.print(" narrative set <path|id> Set run narrative fields (hypothesis/context/...)\n", .{});
|
||||
std.debug.print(" info <path|id> Show run info from run_manifest.json (optionally --base <path>)\n", .{});
|
||||
std.debug.print(" sync <path> Sync project to server\n", .{});
|
||||
std.debug.print(" requeue <id> Re-submit from run_id/task_id/path (supports -- <args>)\n", .{});
|
||||
std.debug.print(" queue (q) <job> Queue job for execution\n", .{});
|
||||
std.debug.print(" init Initialize project with config (use --local for SQLite)\n", .{});
|
||||
std.debug.print(" run [args] Execute a run locally (forks, captures, parses metrics)\n", .{});
|
||||
std.debug.print(" queue <job> Queue job on server (--rerun <id> to re-queue local run)\n", .{});
|
||||
std.debug.print(" annotate <id> Add metadata annotations (hypothesis/outcome/confidence)\n", .{});
|
||||
std.debug.print(" experiment Manage experiments (create, list, show)\n", .{});
|
||||
std.debug.print(" logs <id> Fetch or stream run logs (--follow for live tail)\n", .{});
|
||||
std.debug.print(" sync [id] Push local runs to server (sync_run + sync_ack protocol)\n", .{});
|
||||
std.debug.print(" cancel <id> Cancel local run (SIGTERM/SIGKILL by PID)\n", .{});
|
||||
std.debug.print(" watch [--sync] Watch directory with optional auto-sync\n", .{});
|
||||
std.debug.print(" status Get system status\n", .{});
|
||||
std.debug.print(" monitor Launch TUI via SSH\n", .{});
|
||||
std.debug.print(" logs <id> Fetch job logs (-f to follow, -n for tail)\n", .{});
|
||||
std.debug.print(" cancel <job> Cancel running job\n", .{});
|
||||
std.debug.print(" prune Remove old experiments\n", .{});
|
||||
std.debug.print(" watch <path> Watch directory for auto-sync\n", .{});
|
||||
std.debug.print(" dataset Manage datasets\n", .{});
|
||||
std.debug.print(" experiment Manage experiments and metrics\n", .{});
|
||||
std.debug.print(" validate Validate provenance and integrity for a commit/task\n", .{});
|
||||
std.debug.print(" export <id> Export experiment bundle\n", .{});
|
||||
std.debug.print(" validate Validate provenance and integrity\n", .{});
|
||||
std.debug.print(" compare <a> <b> Compare two runs\n", .{});
|
||||
std.debug.print(" find [query] Search experiments\n", .{});
|
||||
std.debug.print(" jupyter Jupyter workspace management\n", .{});
|
||||
std.debug.print(" info <id> Show run info\n", .{});
|
||||
std.debug.print("\nUse 'ml <command> --help' for detailed help.\n", .{});
|
||||
}
|
||||
|
||||
test {
|
||||
_ = @import("commands/info.zig");
|
||||
_ = @import("commands/requeue.zig");
|
||||
_ = @import("commands/compare.zig");
|
||||
_ = @import("commands/find.zig");
|
||||
_ = @import("commands/export_cmd.zig");
|
||||
_ = @import("commands/log.zig");
|
||||
_ = @import("commands/annotate.zig");
|
||||
_ = @import("commands/narrative.zig");
|
||||
_ = @import("commands/logs.zig");
|
||||
_ = @import("commands/experiment.zig");
|
||||
}
|
||||
|
|
|
|||
383
cli/src/manifest.zig
Normal file
383
cli/src/manifest.zig
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// RunManifest represents a run manifest - identical schema between local and server
/// Schema compatibility is a hard requirement enforced here
pub const RunManifest = struct {
    run_id: []const u8, // unique run identifier
    experiment: []const u8, // owning experiment name
    command: []const u8, // executable that was run
    args: [][]const u8, // command-line arguments (owned; freed in deinit)
    commit_id: ?[]const u8, // VCS commit id, null when unknown
    started_at: []const u8, // start timestamp string
    ended_at: ?[]const u8, // end timestamp, null while running
    status: []const u8, // RUNNING, FINISHED, FAILED, CANCELLED
    exit_code: ?i32, // process exit code, null while running
    params: std.StringHashMap([]const u8), // run parameters (keys and values owned)
    metrics_summary: ?std.StringHashMap(f64), // final metrics (keys owned)
    artifact_path: []const u8, // directory holding run artifacts
    synced: bool, // true once pushed to the server

    /// Returns an empty manifest in the RUNNING state with an initialized
    /// (empty) params map. String fields start as static "" literals.
    pub fn init(allocator: std.mem.Allocator) RunManifest {
        return .{
            .run_id = "",
            .experiment = "",
            .command = "",
            .args = &[_][]const u8{},
            .commit_id = null,
            .started_at = "",
            .ended_at = null,
            .status = "RUNNING",
            .exit_code = null,
            .params = std.StringHashMap([]const u8).init(allocator),
            .metrics_summary = null,
            .artifact_path = "",
            .synced = false,
        };
    }

    /// Frees the args slice, all params entries, and the metric keys.
    /// NOTE(review): the string fields (run_id, status, ended_at, ...) are
    /// NOT freed here even though readManifest heap-duplicates them — and
    /// they cannot safely be, since init() aliases "" literals. Confirm the
    /// leak is acceptable for this short-lived CLI process.
    pub fn deinit(self: *RunManifest, allocator: std.mem.Allocator) void {
        var params_iter = self.params.iterator();
        while (params_iter.next()) |entry| {
            allocator.free(entry.key_ptr.*);
            allocator.free(entry.value_ptr.*);
        }
        self.params.deinit();

        // Metric values are f64; only the duplicated keys need freeing.
        if (self.metrics_summary) |*summary| {
            var summary_iter = summary.iterator();
            while (summary_iter.next()) |entry| {
                allocator.free(entry.key_ptr.*);
            }
            summary.deinit();
        }

        for (self.args) |arg| {
            allocator.free(arg);
        }
        allocator.free(self.args);
    }
};
|
||||
|
||||
/// Write manifest to JSON file.
///
/// Serializes by hand (field by field) rather than via std.json because the
/// params/metrics hash maps have no direct std.json representation.
/// NOTE(review): string values are emitted without JSON escaping; embedded
/// quotes or backslashes would produce invalid JSON — confirm inputs are
/// restricted upstream.
pub fn writeManifest(manifest: RunManifest, path: []const u8, allocator: std.mem.Allocator) !void {
    var file = try std.fs.cwd().createFile(path, .{});
    defer file.close();

    // Local helper: format into a temporary buffer, write it, free it.
    const W = struct {
        fn printf(f: std.fs.File, alloc: std.mem.Allocator, comptime fmt: []const u8, fmt_args: anytype) !void {
            const text = try std.fmt.allocPrint(alloc, fmt, fmt_args);
            defer alloc.free(text);
            try f.writeAll(text);
        }
    };

    try file.writeAll("{\n");

    try W.printf(file, allocator, " \"run_id\": \"{s}\",\n", .{manifest.run_id});
    try W.printf(file, allocator, " \"experiment\": \"{s}\",\n", .{manifest.experiment});
    try W.printf(file, allocator, " \"command\": \"{s}\",\n", .{manifest.command});

    // Args array.
    try file.writeAll(" \"args\": [");
    for (manifest.args, 0..) |arg, i| {
        if (i > 0) try file.writeAll(", ");
        try W.printf(file, allocator, "\"{s}\"", .{arg});
    }
    try file.writeAll("],\n");

    // Commit ID (optional).
    if (manifest.commit_id) |cid| {
        try W.printf(file, allocator, " \"commit_id\": \"{s}\",\n", .{cid});
    } else {
        try file.writeAll(" \"commit_id\": null,\n");
    }

    try W.printf(file, allocator, " \"started_at\": \"{s}\",\n", .{manifest.started_at});

    // Ended at (optional).
    if (manifest.ended_at) |ended| {
        try W.printf(file, allocator, " \"ended_at\": \"{s}\",\n", .{ended});
    } else {
        try file.writeAll(" \"ended_at\": null,\n");
    }

    try W.printf(file, allocator, " \"status\": \"{s}\",\n", .{manifest.status});

    // Exit code (optional).
    if (manifest.exit_code) |code| {
        try W.printf(file, allocator, " \"exit_code\": {d},\n", .{code});
    } else {
        try file.writeAll(" \"exit_code\": null,\n");
    }

    // Params object.
    try file.writeAll(" \"params\": {");
    var params_first = true;
    var params_iter = manifest.params.iterator();
    while (params_iter.next()) |entry| {
        if (!params_first) try file.writeAll(", ");
        params_first = false;
        try W.printf(file, allocator, "\"{s}\": \"{s}\"", .{ entry.key_ptr.*, entry.value_ptr.* });
    }
    try file.writeAll("},\n");

    // Metrics summary (optional).
    if (manifest.metrics_summary) |summary| {
        try file.writeAll(" \"metrics_summary\": {");
        var summary_first = true;
        var summary_iter = summary.iterator();
        while (summary_iter.next()) |entry| {
            if (!summary_first) try file.writeAll(", ");
            summary_first = false;
            try W.printf(file, allocator, "\"{s}\": {d:.4}", .{ entry.key_ptr.*, entry.value_ptr.* });
        }
        try file.writeAll("},\n");
    } else {
        try file.writeAll(" \"metrics_summary\": null,\n");
    }

    try W.printf(file, allocator, " \"artifact_path\": \"{s}\",\n", .{manifest.artifact_path});
    try W.printf(file, allocator, " \"synced\": {}", .{manifest.synced});

    try file.writeAll("\n}\n");
}
|
||||
|
||||
/// Read manifest from JSON file.
///
/// Parses the file with std.json and copies every recognized field into a
/// freshly initialized RunManifest; strings in the result are owned by the
/// caller's allocator. Missing required fields return errors; optional
/// fields keep their init() defaults.
pub fn readManifest(path: []const u8, allocator: std.mem.Allocator) !RunManifest {
    var file = try std.fs.cwd().openFile(path, .{});
    defer file.close();

    // Manifests are small; 1 MiB is a generous upper bound.
    const content = try file.readToEndAlloc(allocator, 1024 * 1024);
    defer allocator.free(content);

    const parsed = try std.json.parseFromSlice(std.json.Value, allocator, content, .{});
    defer parsed.deinit();

    if (parsed.value != .object) {
        return error.InvalidManifest;
    }

    const root = parsed.value.object;
    var manifest = RunManifest.init(allocator);

    // Required fields.
    manifest.run_id = try getStringField(allocator, root, "run_id") orelse return error.MissingRunId;
    manifest.experiment = try getStringField(allocator, root, "experiment") orelse return error.MissingExperiment;
    manifest.command = try getStringField(allocator, root, "command") orelse return error.MissingCommand;
    // NOTE(review): the orelse fallbacks alias string literals while the
    // success path heap-allocates; deinit frees neither (see RunManifest).
    manifest.status = try getStringField(allocator, root, "status") orelse "RUNNING";
    manifest.started_at = try getStringField(allocator, root, "started_at") orelse "";

    // Optional fields.
    manifest.ended_at = try getStringField(allocator, root, "ended_at");
    manifest.commit_id = try getStringField(allocator, root, "commit_id");
    manifest.artifact_path = try getStringField(allocator, root, "artifact_path") orelse "";

    // Synced boolean.
    if (root.get("synced")) |synced_val| {
        if (synced_val == .bool) {
            manifest.synced = synced_val.bool;
        }
    }

    // Exit code.
    if (root.get("exit_code")) |exit_val| {
        if (exit_val == .integer) {
            manifest.exit_code = @intCast(exit_val.integer);
        }
    }

    // Args array.
    // Fix: the original allocated one slot per array item but only filled
    // the string-valued ones, leaving any non-string slot undefined —
    // undefined behavior when those slots were later freed or printed.
    // Count string entries first so every allocated slot is initialized.
    if (root.get("args")) |args_val| {
        if (args_val == .array) {
            var string_count: usize = 0;
            for (args_val.array.items) |arg| {
                if (arg == .string) string_count += 1;
            }
            const args = try allocator.alloc([]const u8, string_count);
            var idx: usize = 0;
            for (args_val.array.items) |arg| {
                if (arg == .string) {
                    args[idx] = try allocator.dupe(u8, arg.string);
                    idx += 1;
                }
            }
            manifest.args = args;
        }
    }

    // Params object (string values only; others are silently skipped).
    if (root.get("params")) |params_val| {
        if (params_val == .object) {
            var params_iter = params_val.object.iterator();
            while (params_iter.next()) |entry| {
                if (entry.value_ptr.* == .string) {
                    const key = try allocator.dupe(u8, entry.key_ptr.*);
                    const value = try allocator.dupe(u8, entry.value_ptr.*.string);
                    try manifest.params.put(key, value);
                }
            }
        }
    }

    // Metrics summary (accepts both float and integer JSON numbers).
    if (root.get("metrics_summary")) |metrics_val| {
        if (metrics_val == .object) {
            var summary = std.StringHashMap(f64).init(allocator);
            var metrics_iter = metrics_val.object.iterator();
            while (metrics_iter.next()) |entry| {
                const val = entry.value_ptr.*;
                if (val == .float) {
                    const key = try allocator.dupe(u8, entry.key_ptr.*);
                    try summary.put(key, val.float);
                } else if (val == .integer) {
                    const key = try allocator.dupe(u8, entry.key_ptr.*);
                    try summary.put(key, @floatFromInt(val.integer));
                }
            }
            manifest.metrics_summary = summary;
        }
    }

    return manifest;
}
|
||||
|
||||
/// Duplicate the string value of `field` from `obj`.
/// Returns null when the field is absent or not a JSON string; on success
/// the caller owns the returned slice.
fn getStringField(allocator: std.mem.Allocator, obj: std.json.ObjectMap, field: []const u8) !?[]const u8 {
    if (obj.get(field)) |val| {
        return switch (val) {
            .string => |s| try allocator.dupe(u8, s),
            else => null,
        };
    }
    return null;
}
|
||||
|
||||
/// Update manifest status and ended_at on run completion
/// Round-trips the manifest through read → mutate → write; `status` and
/// `exit_code` are taken as-is and ended_at is set to the current UTC time.
pub fn updateManifestStatus(path: []const u8, status: []const u8, exit_code: ?i32, allocator: std.mem.Allocator) !void {
    var manifest = try readManifest(path, allocator);
    defer manifest.deinit(allocator);

    // `status` is borrowed from the caller; it only needs to outlive the
    // writeManifest call below.
    manifest.status = status;
    manifest.exit_code = exit_code;

    // Set ended_at to current timestamp, formatted "YYYY-MM-DDThh:mm:ssZ".
    const now = std.time.timestamp();
    const epoch_seconds = std.time.epoch.EpochSeconds{ .secs = @intCast(now) };
    const epoch_day = epoch_seconds.getEpochDay();
    const year_day = epoch_day.calculateYearDay();
    const month_day = year_day.calculateMonthDay();
    const day_seconds = epoch_seconds.getDaySeconds();

    var buf: [30]u8 = undefined;
    // bufPrint cannot fail here: the fixed-width output always fits in 30 bytes.
    const timestamp = std.fmt.bufPrint(&buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        year_day.year,
        month_day.month.numeric(),
        month_day.day_index + 1,
        day_seconds.getHoursIntoDay(),
        day_seconds.getMinutesIntoHour(),
        day_seconds.getSecondsIntoMinute(),
    }) catch unreachable;

    // NOTE(review): this dupe is never freed — RunManifest.deinit does not
    // release ended_at; confirm the leak is acceptable for a short-lived CLI.
    manifest.ended_at = try allocator.dupe(u8, timestamp);

    try writeManifest(manifest, path, allocator);
}
|
||||
|
||||
/// Mark manifest as synced
/// Round-trips the manifest through read → set flag → write; `synced` is
/// the only field changed.
pub fn markManifestSynced(path: []const u8, allocator: std.mem.Allocator) !void {
    var manifest = try readManifest(path, allocator);
    defer manifest.deinit(allocator);

    manifest.synced = true;
    try writeManifest(manifest, path, allocator);
}
|
||||
|
||||
/// Build the canonical manifest path:
/// <artifact_path>/<experiment>/<run_id>/run_manifest.json.
/// The caller owns the returned slice.
pub fn buildManifestPath(artifact_path: []const u8, experiment: []const u8, run_id: []const u8, allocator: std.mem.Allocator) ![]const u8 {
    const segments = [_][]const u8{ artifact_path, experiment, run_id, "run_manifest.json" };
    return std.fs.path.join(allocator, &segments);
}
|
||||
|
||||
/// Resolve manifest path from input (path, run_id, or task_id).
/// An existing path wins: a file is returned as-is, a directory gets
/// "run_manifest.json" appended. Otherwise the input is treated as a
/// run_id and searched for under `base_path`. Caller owns the result.
pub fn resolveManifestPath(input: []const u8, base_path: ?[]const u8, allocator: std.mem.Allocator) ![]const u8 {
    // Local helper: if `p` exists on disk, return its manifest path
    // (the path itself when statFile succeeds, `<p>/run_manifest.json`
    // when statFile fails, i.e. `p` is a directory). Returns null when
    // `p` is not accessible at all.
    const H = struct {
        fn fromExisting(alloc: std.mem.Allocator, p: []const u8) !?[]const u8 {
            std.fs.cwd().access(p, .{}) catch return null;
            _ = std.fs.cwd().statFile(p) catch {
                return try std.fs.path.join(alloc, &[_][]const u8{ p, "run_manifest.json" });
            };
            return try alloc.dupe(u8, p);
        }
    };

    // One existence check covers both the original's absolute and relative
    // branches: cwd().access handles absolute inputs identically, and the
    // original's second (relative) check could never succeed where the
    // first had already failed.
    if (try H.fromExisting(allocator, input)) |resolved| {
        return resolved;
    }

    // Not a path — search base_path for a matching run_id.
    if (base_path) |bp| {
        return try findManifestById(bp, input, allocator);
    }

    return error.ManifestNotFound;
}
|
||||
|
||||
/// Find manifest by run_id in base path.
/// Scans each experiment directory under `base_path` for a subdirectory
/// named `id` containing run_manifest.json, returning the first match
/// (caller owns the slice) or error.ManifestNotFound.
fn findManifestById(base_path: []const u8, id: []const u8, allocator: std.mem.Allocator) ![]const u8 {
    var root = std.fs.cwd().openDir(base_path, .{ .iterate = true }) catch {
        return error.ManifestNotFound;
    };
    defer root.close();

    var it = root.iterate();
    while (try it.next()) |entry| {
        if (entry.kind != .directory) continue;

        // Candidate: <base_path>/<experiment>/<id>/run_manifest.json
        // (built in one join instead of the original's two).
        const candidate = try std.fs.path.join(allocator, &[_][]const u8{
            base_path,
            entry.name,
            id,
            "run_manifest.json",
        });

        if (std.fs.cwd().access(candidate, .{})) {
            return candidate;
        } else |_| {
            allocator.free(candidate);
        }
    }

    return error.ManifestNotFound;
}
|
||||
108
cli/src/mode.zig
Normal file
108
cli/src/mode.zig
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
const std = @import("std");
|
||||
const Config = @import("config.zig").Config;
|
||||
const ws = @import("net/ws/client.zig");
|
||||
|
||||
/// Mode represents the operating mode of the CLI
/// (chosen per-command by detect() below).
pub const Mode = enum {
    /// Local/offline mode - runs execute locally, tracking to SQLite
    offline,
    /// Online/runner mode - jobs queue to remote server
    online,
};

/// DetectionResult includes the mode and any warning messages
pub const DetectionResult = struct {
    mode: Mode, // the mode to operate in
    warning: ?[]const u8, // static message explaining a fallback, or null
};
|
||||
|
||||
/// Detect mode based on configuration and environment
/// Priority order (CLI — checked on every command):
/// 1. FETCHML_LOCAL=1 env var → local (forced, skip ping)
/// 2. force_local=true in config → local (forced, skip ping)
/// 3. cfg.Host == "" → local (not configured)
/// 4. API ping within 2s timeout → runner mode
///    - timeout / refused → local (fallback, log once per session)
///    - 401/403 → local (fallback, warn once about auth)
pub fn detect(allocator: std.mem.Allocator, cfg: Config) !DetectionResult {
    // Priority 1: FETCHML_LOCAL env var (only the exact value "1" forces local).
    if (std.posix.getenv("FETCHML_LOCAL")) |val| {
        if (std.mem.eql(u8, val, "1")) {
            return .{ .mode = .offline, .warning = null };
        }
    }

    // Priority 2: force_local in config
    if (cfg.force_local) {
        return .{ .mode = .offline, .warning = null };
    }

    // Priority 3: No host configured
    if (cfg.worker_host.len == 0) {
        return .{ .mode = .offline, .warning = null };
    }

    // Priority 4: API ping with 2s timeout. Every failure mode degrades to
    // offline with a human-readable warning rather than returning an error.
    const ping_result = try pingServer(allocator, cfg, 2000);
    return switch (ping_result) {
        .success => .{ .mode = .online, .warning = null },
        .timeout => .{ .mode = .offline, .warning = "Server unreachable, falling back to local mode" },
        .refused => .{ .mode = .offline, .warning = "Server connection refused, falling back to local mode" },
        .auth_error => .{ .mode = .offline, .warning = "Authentication failed, falling back to local mode" },
    };
}
|
||||
|
||||
/// PingResult represents the outcome of a server ping
const PingResult = enum {
    success, // server reachable and responding
    timeout, // connection or first read timed out
    refused, // connection refused (or any other transport error)
    auth_error, // server rejected our credentials
};
|
||||
|
||||
/// Ping the server with a timeout - simplified version that just tries to connect
/// Opens a WebSocket connection and waits for any message, mapping transport
/// errors onto the PingResult cases detect() understands. Never returns an
/// error for connection failures — only for URL allocation.
fn pingServer(allocator: std.mem.Allocator, cfg: Config, timeout_ms: u64) !PingResult {
    _ = timeout_ms; // Timeout not implemented for this simplified version

    const ws_url = try cfg.getWebSocketUrl(allocator);
    defer allocator.free(ws_url);

    var connection = ws.Client.connect(allocator, ws_url, cfg.api_key) catch |err| {
        switch (err) {
            error.ConnectionTimedOut => return .timeout,
            error.ConnectionRefused => return .refused,
            error.AuthenticationFailed => return .auth_error,
            // Any other transport error is treated as a refusal.
            else => return .refused,
        }
    };
    defer connection.close();

    // Try to receive any message to confirm server is responding
    const response = connection.receiveMessage(allocator) catch |err| {
        switch (err) {
            error.ConnectionTimedOut => return .timeout,
            else => return .refused,
        }
    };
    defer allocator.free(response);

    return .success;
}
|
||||
|
||||
/// Check if mode is online.
pub fn isOnline(mode: Mode) bool {
    return switch (mode) {
        .online => true,
        .offline => false,
    };
}
|
||||
|
||||
/// Check if mode is offline.
pub fn isOffline(mode: Mode) bool {
    return switch (mode) {
        .offline => true,
        .online => false,
    };
}
|
||||
|
||||
/// Require online mode, returning error.RequiresServer (after logging the
/// offending command name) when offline.
pub fn requireOnline(mode: Mode, command_name: []const u8) !void {
    if (mode != .offline) return;
    std.log.err("{s} requires server connection", .{command_name});
    return error.RequiresServer;
}
|
||||
71
cli/src/native/hash.zig
Normal file
71
cli/src/native/hash.zig
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
const std = @import("std");
|
||||
const c = @cImport({
|
||||
@cInclude("dataset_hash.h");
|
||||
});
|
||||
|
||||
pub const HashError = error{
    ContextInitFailed, // fh_init returned null
    HashFailed, // native hash call returned null
    InvalidPath,
    OutOfMemory,
};

// Global context for reuse across multiple hash operations
// (created once by init(), never torn down for the process lifetime).
var global_ctx: ?*c.fh_context_t = null;
var ctx_initialized = std.atomic.Value(bool).init(false);
var init_mutex = std.Thread.Mutex{};
|
||||
|
||||
/// Initialize global hash context once (thread-safe)
/// Double-checked locking: a fast atomic read, then a mutex-guarded
/// re-check before the actual fh_init call. On failure the flag stays
/// false, so a later call may retry.
pub fn init() !void {
    if (ctx_initialized.load(.seq_cst)) return;

    init_mutex.lock();
    defer init_mutex.unlock();

    if (ctx_initialized.load(.seq_cst)) return; // Double-check

    // Time the native init purely for the log line below.
    const start = std.time.milliTimestamp();
    global_ctx = c.fh_init(0); // 0 = auto-detect threads
    const elapsed = std.time.milliTimestamp() - start;

    if (global_ctx == null) {
        return HashError.ContextInitFailed;
    }

    ctx_initialized.store(true, .seq_cst);
    std.log.info("[native] hash context initialized: {}ms", .{elapsed});
}
|
||||
|
||||
/// Hash a directory using the native library (reuses global context)
/// Returns the hex-encoded SHA256 hash string; the caller owns the slice.
pub fn hashDirectory(allocator: std.mem.Allocator, path: []const u8) ![]const u8 {
    try init(); // Idempotent initialization

    const ctx = global_ctx.?; // Safe: init() guarantees non-null

    // Convert path to null-terminated C string
    const c_path = try allocator.dupeZ(u8, path);
    defer allocator.free(c_path);

    // Call native function
    const result = c.fh_hash_directory_combined(ctx, c_path);
    if (result == null) {
        return HashError.HashFailed;
    }
    defer c.fh_free_string(result);

    // Copy the C-owned result into allocator-owned memory before the
    // deferred fh_free_string releases it.
    const result_slice = std.mem.span(result);
    return try allocator.dupe(u8, result_slice);
}
|
||||
|
||||
/// Check if SIMD SHA256 is available
/// (the native call returns 1 for "available").
pub fn hasSimdSha256() bool {
    return c.fh_has_simd_sha256() == 1;
}
|
||||
|
||||
/// Get the name of the SIMD implementation being used.
/// The returned slice aliases a C-owned NUL-terminated string; do not free.
pub fn getSimdImplName() []const u8 {
    const name = c.fh_get_simd_impl_name();
    return std.mem.span(name);
}
|
||||
262
cli/src/native/macos_gpu.zig
Normal file
262
cli/src/native/macos_gpu.zig
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
/// macOS GPU Monitoring for Development Mode
|
||||
/// Uses system_profiler and powermetrics for GPU info
|
||||
/// Only available on macOS
|
||||
const c = @cImport({
|
||||
@cInclude("sys/types.h");
|
||||
@cInclude("sys/sysctl.h");
|
||||
});
|
||||
|
||||
/// GPU information structure for macOS
pub const MacOSGPUInfo = struct {
    index: u32, // ordinal among detected GPUs
    name: [256:0]u8, // NUL-terminated display name (fixed buffer)
    chipset_model: [256:0]u8, // NUL-terminated chipset string
    vram_mb: u32, // VRAM in MiB
    is_integrated: bool, // integrated vs discrete GPU
    // Performance metrics (if available via powermetrics)
    utilization_percent: ?u32,
    temperature_celsius: ?u32,
    power_mw: ?u32,
};
|
||||
|
||||
/// Detect if running on Apple Silicon
/// Queries sysctl hw.machine and checks for an "arm64"/"Apple" prefix.
/// Always false off macOS or when the sysctl call fails.
pub fn isAppleSilicon() bool {
    if (builtin.os.tag != .macos) return false;

    var buf: [64]u8 = undefined;
    var len: usize = buf.len;
    const mib = [_]c_int{ c.CTL_HW, c.HW_MACHINE };

    const result = c.sysctl(&mib[0], 2, &buf[0], &len, null, 0);
    if (result != 0) return false;

    // sysctl writes a NUL-terminated machine string into buf.
    const machine = std.mem.sliceTo(&buf, 0);
    return std.mem.startsWith(u8, machine, "arm64") or
        std.mem.startsWith(u8, machine, "Apple");
}
|
||||
|
||||
/// Get GPU count on macOS.
/// Counts "Chipset Model" lines in `system_profiler SPDisplaysDataType`
/// output — one per display adapter. Returns 0 on non-macOS platforms or
/// on any failure (best-effort probe, never errors).
pub fn getGPUCount() u32 {
    if (builtin.os.tag != .macos) return 0;

    // Run system_profiler to check for GPUs.
    const result = runSystemProfiler() catch return 0;
    // BUG FIX: runSystemProfiler allocates with std.heap.page_allocator;
    // the buffer must be freed with the same allocator. The original freed
    // it with raw_c_allocator — an allocator mismatch (undefined behavior).
    defer std.heap.page_allocator.free(result);

    // Parse output for GPU entries: each adapter has one "Chipset Model" line.
    var lines = std.mem.splitScalar(u8, result, '\n');
    var count: u32 = 0;
    while (lines.next()) |line| {
        if (std.mem.indexOf(u8, line, "Chipset Model") != null) {
            count += 1;
        }
    }

    return count;
}
|
||||
|
||||
/// Run `system_profiler SPDisplaysDataType -json` and capture stdout.
/// Caller owns the returned buffer, allocated with std.heap.page_allocator.
/// Errors: spawn/read failures, or error.CommandFailed on non-zero exit.
fn runSystemProfiler() ![]u8 {
    const argv = [_][]const u8{
        "system_profiler",
        "SPDisplaysDataType",
        "-json",
    };

    var child = std.process.Child.init(&argv, std.heap.page_allocator);
    child.stdout_behavior = .Pipe;
    child.stderr_behavior = .Ignore;

    try child.spawn();
    // BUG FIX: only kill on the error path. The original `defer kill()`
    // also fired after wait() had already reaped the child, signalling a
    // stale (potentially reused) PID.
    errdefer _ = child.kill() catch {};

    const stdout = child.stdout.?.reader();
    // Cap at 1 MiB; display data is far smaller in practice.
    const output = try stdout.readAllAlloc(std.heap.page_allocator, 1024 * 1024);
    // BUG FIX: free the captured output if we fail below — the original
    // leaked it when the command exited non-zero.
    errdefer std.heap.page_allocator.free(output);

    const term = try child.wait();
    if (term != .Exited or term.Exited != 0) {
        return error.CommandFailed;
    }

    return output;
}
|
||||
|
||||
/// Parse GPU info from system_profiler JSON output.
/// Scans for the "_items" array and hands each top-level {...} object to
/// parseGPUObject; objects without a chipset model are skipped.
/// Caller owns the returned slice (allocated with `allocator`).
pub fn parseGPUInfo(allocator: std.mem.Allocator, json_output: []const u8) ![]MacOSGPUInfo {
    // Simple parser for system_profiler JSON
    // Format: {"SPDisplaysDataType": [{"sppci_model":"...", "sppci_vram":"...", ...}, ...]}

    var gpus = std.ArrayList(MacOSGPUInfo).init(allocator);
    // deinit is a no-op after toOwnedSlice drains the list; it also
    // releases partial contents if an append fails mid-parse.
    defer gpus.deinit();

    // Parse JSON - look for _items array
    const items_key = "_items";
    if (std.mem.indexOf(u8, json_output, items_key)) |items_start| {
        const rest = json_output[items_start..];
        // Find array start
        if (std.mem.indexOf(u8, rest, "[")) |array_start| {
            const array = rest[array_start..];
            // Simple heuristic: find objects between { and }
            // NOTE(review): gpu.index is never assigned here, so every
            // entry keeps index 0 — confirm whether callers need ordinals.
            var i: usize = 0;
            while (i < array.len) {
                if (array[i] == '{') {
                    // Found object start; findObjectEnd returns the length
                    // of the balanced {...} span (or null if unterminated).
                    if (findObjectEnd(array[i..])) |obj_end| {
                        const obj = array[i .. i + obj_end];
                        if (try parseGPUObject(obj)) |gpu| {
                            try gpus.append(gpu);
                        }
                        // Skip past the whole object (covers nested objects).
                        i += obj_end;
                        continue;
                    }
                }
                i += 1;
            }
        }
    }

    return gpus.toOwnedSlice();
}
|
||||
|
||||
/// Return the length of the balanced {...} span at the start of `json`,
/// counting nested braces and ignoring braces inside string literals.
/// Returns null when the object is unterminated.
fn findObjectEnd(json: []const u8) ?usize {
    var depth: i32 = 0;
    var in_string = false;
    // BUG FIX: track escape state explicitly. The original checked only
    // `json[i-1] != '\\'`, which mis-handles an escaped backslash before a
    // quote (`"...\\"`): the closing quote was treated as escaped, so the
    // scanner never left the string and returned null.
    var escaped = false;
    for (json, 0..) |char, i| {
        if (in_string) {
            if (escaped) {
                escaped = false; // this char is literal, whatever it is
            } else if (char == '\\') {
                escaped = true;
            } else if (char == '"') {
                in_string = false;
            }
        } else if (char == '"') {
            in_string = true;
        } else if (char == '{') {
            depth += 1;
        } else if (char == '}') {
            depth -= 1;
            if (depth == 0) {
                return i + 1;
            }
        }
    }
    return null;
}
|
||||
|
||||
/// Parse a single system_profiler display entry into MacOSGPUInfo.
/// `json` must be one complete {...} object. Returns null when the object
/// carries no "sppci_model" key, i.e. it is not a GPU entry.
fn parseGPUObject(json: []const u8) !?MacOSGPUInfo {
    var gpu = MacOSGPUInfo{
        .index = 0,
        .name = std.mem.zeroes([256:0]u8),
        .chipset_model = std.mem.zeroes([256:0]u8),
        .vram_mb = 0,
        .is_integrated = false,
        .utilization_percent = null,
        .temperature_celsius = null,
        .power_mw = null,
    };

    // Extract sppci_model — used for both the display name and the model.
    if (extractJsonString(json, "sppci_model")) |model| {
        const len = @min(model.len, 255); // truncate; buffer keeps its NUL sentinel
        @memcpy(gpu.chipset_model[0..len], model[0..len]);
        @memcpy(gpu.name[0..len], model[0..len]);
    }

    // Extract sppci_vram: the "_shared" variant marks unified memory.
    if (extractJsonString(json, "sppci_vram_shared")) |_| {
        gpu.is_integrated = true;
        gpu.vram_mb = 0; // Shared memory
    } else if (extractJsonString(json, "sppci_vram")) |vram| {
        // Parse "16384 MB" -> 16384 (take the number before the first space).
        var it = std.mem.splitScalar(u8, vram, ' ');
        if (it.next()) |num_str| {
            gpu.vram_mb = std.fmt.parseInt(u32, num_str, 10) catch 0;
        }
    }

    // Check if it's a valid GPU entry: a non-empty model string was found.
    if (gpu.chipset_model[0] == 0) {
        return null;
    }

    return gpu;
}
|
||||
|
||||
/// Extract the string value for `key` from a flat JSON object.
/// Returns a slice INTO `json` (no copy), or null when the key is absent,
/// its value is not a string, or the key is too long to quote.
/// NOTE: escape sequences inside the value are not decoded; the slice runs
/// up to the first '"' byte.
fn extractJsonString(json: []const u8, key: []const u8) ?[]const u8 {
    // IMPROVED: build the quoted key on the stack instead of the original
    // page_allocator round trip per call. Keys used here are short
    // identifiers ("sppci_model", ...), so 128 bytes is ample; overly long
    // keys simply return null, as the old OOM path did.
    var key_buf: [128]u8 = undefined;
    const key_quoted = std.fmt.bufPrint(&key_buf, "\"{s}\"", .{key}) catch return null;

    const key_pos = std.mem.indexOf(u8, json, key_quoted) orelse return null;
    const after_key = json[key_pos + key_quoted.len ..];

    // Find value start (skip ':' and whitespace).
    var i: usize = 0;
    while (i < after_key.len and (after_key[i] == ':' or after_key[i] == ' ' or after_key[i] == '\t' or after_key[i] == '\n')) : (i += 1) {}

    // Only string values are supported; anything else yields null.
    if (i >= after_key.len or after_key[i] != '"') return null;

    const str_start = i + 1;
    var str_end = str_start;
    while (str_end < after_key.len and after_key[str_end] != '"') : (str_end += 1) {}
    return after_key[str_start..str_end];
}
|
||||
|
||||
/// Format GPU info for display.
/// Renders a human-readable report for the given GPUs. Caller owns the
/// returned buffer (allocated with `allocator`).
pub fn formatMacOSGPUInfo(allocator: std.mem.Allocator, gpus: []const MacOSGPUInfo) ![]u8 {
    var buf = std.ArrayList(u8).init(allocator);
    defer buf.deinit();

    const writer = buf.writer();

    // Empty case gets its own short report and returns early.
    if (gpus.len == 0) {
        try writer.writeAll("GPU Status (macOS)\n");
        try writer.writeAll("═" ** 50);
        try writer.writeAll("\n\nNo GPUs detected\n");
        return buf.toOwnedSlice();
    }

    try writer.writeAll("GPU Status (macOS");
    if (isAppleSilicon()) {
        try writer.writeAll(" - Apple Silicon");
    }
    try writer.writeAll(")\n");
    try writer.writeAll("═" ** 50);
    try writer.writeAll("\n\n");

    for (gpus) |gpu| {
        const name = std.mem.sliceTo(&gpu.name, 0);
        const model = std.mem.sliceTo(&gpu.chipset_model, 0);

        try writer.print("🎮 GPU {d}: {s}\n", .{ gpu.index, name });
        // Only show the model line when it differs from the name.
        if (!std.mem.eql(u8, model, name)) {
            try writer.print(" Model: {s}\n", .{model});
        }
        if (gpu.is_integrated) {
            try writer.writeAll(" Type: Integrated (Unified Memory)\n");
        } else {
            try writer.print(" VRAM: {d} MB\n", .{gpu.vram_mb});
        }
        // Optional telemetry — printed only when present.
        if (gpu.utilization_percent) |util| {
            try writer.print(" Utilization: {d}%\n", .{util});
        }
        if (gpu.temperature_celsius) |temp| {
            try writer.print(" Temperature: {d}°C\n", .{temp});
        }
        if (gpu.power_mw) |power| {
            // BUG FIX: "{d:.1f}" is not a valid Zig format spec (the stray
            // 'f' is a compile error); "{d:.1}" matches the NVML formatter.
            try writer.print(" Power: {d:.1} W\n", .{@as(f64, @floatFromInt(power)) / 1000.0});
        }
        try writer.writeAll("\n");
    }

    try writer.writeAll("💡 Note: Detailed GPU metrics require powermetrics (sudo)\n");

    return buf.toOwnedSlice();
}
|
||||
|
||||
/// Quick check for GPU availability on macOS.
/// False on every other OS; otherwise true when at least one GPU entry is
/// found by getGPUCount().
pub fn isMacOSGPUAvailable() bool {
    if (builtin.os.tag != .macos) return false;
    const detected = getGPUCount();
    return detected > 0;
}
|
||||
372
cli/src/native/nvml.zig
Normal file
372
cli/src/native/nvml.zig
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
/// NVML Dynamic Loader for CLI
|
||||
/// Pure Zig implementation using dlopen/LoadLibrary
|
||||
/// No build-time dependency on NVIDIA SDK
|
||||
|
||||
// Platform-specific dynamic loading.
// Windows uses LoadLibrary/GetProcAddress; everything else uses
// dlopen/dlsym via extern "c" declarations so there is no build-time
// dependency on the NVIDIA SDK.
const DynLib = switch (builtin.os.tag) {
    .windows => struct {
        handle: std.os.windows.HMODULE,

        // Load a DLL by path. error.LibraryNotFound when LoadLibraryW fails.
        fn open(path: []const u8) !@This() {
            const wide_path = try std.os.windows.sliceToPrefixedFileW(path);
            const handle = std.os.windows.LoadLibraryW(&wide_path.data) orelse return error.LibraryNotFound;
            return .{ .handle = handle };
        }

        fn close(self: *@This()) void {
            _ = std.os.windows.FreeLibrary(self.handle);
        }

        // Resolve a symbol; null when absent.
        fn lookup(self: @This(), name: []const u8) ?*anyopaque {
            return std.os.windows.GetProcAddress(self.handle, name);
        }
    },
    else => struct {
        handle: *anyopaque,

        // Extern declarations for dlopen/dlsym (libc).
        extern "c" fn dlopen(pathname: [*:0]const u8, mode: c_int) ?*anyopaque;
        extern "c" fn dlsym(handle: *anyopaque, symbol: [*:0]const u8) ?*anyopaque;
        extern "c" fn dlclose(handle: *anyopaque) c_int;

        const RTLD_NOW = 2;

        // Load a shared object by path via dlopen(RTLD_NOW).
        fn open(path: []const u8) !@This() {
            // BUG FIX: std.cstr.addNullByte was removed from the Zig
            // standard library this file otherwise targets (it already uses
            // std.process.Child / std.mem.splitScalar); dupeZ produces the
            // NUL-terminated copy dlopen requires.
            const c_path = try std.heap.c_allocator.dupeZ(u8, path);
            defer std.heap.c_allocator.free(c_path);
            const handle = dlopen(c_path.ptr, RTLD_NOW) orelse return error.LibraryNotFound;
            return .{ .handle = handle };
        }

        fn close(self: *@This()) void {
            _ = dlclose(self.handle);
        }

        // Resolve a symbol; null when absent (or on OOM while building the
        // NUL-terminated name — indistinguishable from "missing" here).
        fn lookup(self: @This(), name: []const u8) ?*anyopaque {
            const c_name = std.heap.c_allocator.dupeZ(u8, name) catch return null;
            defer std.heap.c_allocator.free(c_name);
            return dlsym(self.handle, c_name.ptr);
        }
    },
};
|
||||
|
||||
// NVML type definitions (mirrors nvml.h).
// Status code returned by every NVML call; 0 (NVML_SUCCESS) means success.
pub const nvmlReturn_t = c_int;
// Opaque device handle obtained via nvmlDeviceGetHandleByIndex_v2.
pub const nvmlDevice_t = *anyopaque;

// Result of nvmlDeviceGetUtilizationRates; values are percentages.
pub const nvmlUtilization_t = extern struct {
    gpu: c_uint,
    memory: c_uint,
};

// Result of nvmlDeviceGetMemoryInfo; amounts are in bytes
// (formatGPUInfo divides by 1024*1024 for MB display).
pub const nvmlMemory_t = extern struct {
    total: c_ulonglong,
    free: c_ulonglong,
    used: c_ulonglong,
};
|
||||
|
||||
// NVML constants (subset of nvml.h used below).
const NVML_SUCCESS = 0; // nvmlReturn_t success value
const NVML_TEMPERATURE_GPU = 0; // sensor selector for nvmlDeviceGetTemperature
const NVML_CLOCK_SM = 0; // clock selector: SM clock
const NVML_CLOCK_MEM = 1; // clock selector: memory clock

// NVML function pointer types; signatures mirror the C prototypes so the
// dlsym'd addresses can be called directly with C calling convention.
const nvmlInit_v2_fn = *const fn () callconv(.C) nvmlReturn_t;
const nvmlShutdown_fn = *const fn () callconv(.C) nvmlReturn_t;
const nvmlDeviceGetCount_fn = *const fn (*c_uint) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetHandleByIndex_v2_fn = *const fn (c_uint, *nvmlDevice_t) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetName_fn = *const fn (nvmlDevice_t, [*]u8, c_uint) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetUtilizationRates_fn = *const fn (nvmlDevice_t, *nvmlUtilization_t) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetMemoryInfo_fn = *const fn (nvmlDevice_t, *nvmlMemory_t) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetTemperature_fn = *const fn (nvmlDevice_t, c_uint, *c_uint) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetPowerUsage_fn = *const fn (nvmlDevice_t, *c_uint) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetClockInfo_fn = *const fn (nvmlDevice_t, c_uint, *c_uint) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetUUID_fn = *const fn (nvmlDevice_t, [*]u8, c_uint) callconv(.C) nvmlReturn_t;
const nvmlDeviceGetVbiosVersion_fn = *const fn (nvmlDevice_t, [*]u8, c_uint) callconv(.C) nvmlReturn_t;
|
||||
|
||||
/// GPU information structure.
/// A snapshot of one device's identity and telemetry filled in by
/// NVML.getGPUInfo; fields stay zeroed when the corresponding optional
/// NVML query is unavailable or fails.
pub const GPUInfo = struct {
    index: u32, // device index as passed to getGPUInfo
    name: [256:0]u8, // device name, NUL-terminated
    utilization: u32, // from nvmlDeviceGetUtilizationRates (.gpu field)
    memory_used: u64, // bytes (formatGPUInfo converts to MB)
    memory_total: u64, // bytes
    temperature: u32, // NVML_TEMPERATURE_GPU sensor reading
    power_draw: u32, // milliwatts — formatGPUInfo divides by 1000 for watts
    clock_sm: u32, // SM clock (printed as MHz)
    clock_memory: u32, // memory clock
    uuid: [64:0]u8, // device UUID string
    vbios_version: [32:0]u8, // VBIOS version string
};
|
||||
|
||||
/// NVML handle with loaded functions.
/// Owns the dynamically loaded NVIDIA Management Library plus typed
/// pointers to its entry points. Obtain via load(); release via unload().
pub const NVML = struct {
    lib: DynLib,
    available: bool, // true once nvmlInit_v2 has succeeded

    // Function pointers.
    // Required entry points — load() errors if any is missing:
    init: nvmlInit_v2_fn,
    shutdown: nvmlShutdown_fn,
    get_count: nvmlDeviceGetCount_fn,
    get_handle_by_index: nvmlDeviceGetHandleByIndex_v2_fn,
    // Optional entry points — null when the driver does not export them:
    get_name: ?nvmlDeviceGetName_fn,
    get_utilization: ?nvmlDeviceGetUtilizationRates_fn,
    get_memory: ?nvmlDeviceGetMemoryInfo_fn,
    get_temperature: ?nvmlDeviceGetTemperature_fn,
    get_power_usage: ?nvmlDeviceGetPowerUsage_fn,
    get_clock: ?nvmlDeviceGetClockInfo_fn,
    get_uuid: ?nvmlDeviceGetUUID_fn,
    get_vbios: ?nvmlDeviceGetVbiosVersion_fn,

    // Last human-readable error message set by setError (NUL-terminated).
    // NOTE(review): load() starts from `undefined`, so this buffer holds
    // garbage until the first setError call — confirm getLastError is only
    // consumed after an error path has run.
    last_error: [256:0]u8,

    /// Load NVML dynamically.
    /// Returns null when no NVIDIA library is present (not an error), an
    /// error when the library loads but required symbols or nvmlInit_v2
    /// fail, otherwise a ready handle. Pair with unload().
    pub fn load() !?NVML {
        var nvml: NVML = undefined;

        // Try platform-specific library names.
        const lib_names = switch (builtin.os.tag) {
            .windows => &[_][]const u8{
                "nvml.dll",
                "C:\\Windows\\System32\\nvml.dll",
            },
            .linux => &[_][]const u8{
                "libnvidia-ml.so.1",
                "libnvidia-ml.so",
            },
            else => return null, // NVML not supported on other platforms
        };

        // Try to load library — first name that opens wins.
        var loaded = false;
        for (lib_names) |name| {
            if (DynLib.open(name)) |lib| {
                nvml.lib = lib;
                loaded = true;
                break;
            } else |_| continue;
        }

        if (!loaded) {
            return null; // NVML not available (no NVIDIA driver)
        }

        // Load required functions — any missing symbol aborts the load.
        // NOTE(review): the symbol queried is "nvmlDeviceGetCount" while
        // the handle lookup uses the _v2 suffix — confirm this asymmetry
        // is intentional.
        nvml.init = @ptrCast(nvml.lib.lookup("nvmlInit_v2") orelse return error.InitNotFound);
        nvml.shutdown = @ptrCast(nvml.lib.lookup("nvmlShutdown") orelse return error.ShutdownNotFound);
        nvml.get_count = @ptrCast(nvml.lib.lookup("nvmlDeviceGetCount") orelse return error.GetCountNotFound);
        nvml.get_handle_by_index = @ptrCast(nvml.lib.lookup("nvmlDeviceGetHandleByIndex_v2") orelse return error.GetHandleNotFound);

        // Load optional functions — each may resolve to null.
        nvml.get_name = @ptrCast(nvml.lib.lookup("nvmlDeviceGetName"));
        nvml.get_utilization = @ptrCast(nvml.lib.lookup("nvmlDeviceGetUtilizationRates"));
        nvml.get_memory = @ptrCast(nvml.lib.lookup("nvmlDeviceGetMemoryInfo"));
        nvml.get_temperature = @ptrCast(nvml.lib.lookup("nvmlDeviceGetTemperature"));
        nvml.get_power_usage = @ptrCast(nvml.lib.lookup("nvmlDeviceGetPowerUsage"));
        nvml.get_clock = @ptrCast(nvml.lib.lookup("nvmlDeviceGetClockInfo"));
        nvml.get_uuid = @ptrCast(nvml.lib.lookup("nvmlDeviceGetUUID"));
        nvml.get_vbios = @ptrCast(nvml.lib.lookup("nvmlDeviceGetVbiosVersion"));

        // Initialize NVML; on failure the library is closed before erroring.
        const result = nvml.init();
        if (result != NVML_SUCCESS) {
            nvml.setError("NVML initialization failed");
            nvml.lib.close();
            return error.NVMLInitFailed;
        }

        nvml.available = true;
        return nvml;
    }

    /// Unload NVML: shut the driver session down (if it was initialized)
    /// and close the shared library.
    pub fn unload(self: *NVML) void {
        if (self.available) {
            _ = self.shutdown();
        }
        self.lib.close();
    }

    /// Check if NVML is available (init succeeded).
    pub fn isAvailable(self: NVML) bool {
        return self.available;
    }

    /// Get last error message recorded by setError.
    pub fn getLastError(self: NVML) []const u8 {
        return std.mem.sliceTo(&self.last_error, 0);
    }

    // Record a truncated, NUL-terminated copy of msg into last_error.
    fn setError(self: *NVML, msg: []const u8) void {
        @memset(&self.last_error, 0);
        const len = @min(msg.len, self.last_error.len - 1);
        @memcpy(self.last_error[0..len], msg[0..len]);
    }

    /// Get number of GPUs via nvmlDeviceGetCount.
    pub fn getGPUCount(self: *NVML) !u32 {
        var count: c_uint = 0;
        const result = self.get_count(&count);
        if (result != NVML_SUCCESS) {
            self.setError("Failed to get GPU count");
            return error.GetCountFailed;
        }
        return @intCast(count);
    }

    /// Get GPU info by index.
    /// Required: the device handle. All telemetry queries are best-effort —
    /// a missing symbol or failed call leaves the field at its zero value.
    pub fn getGPUInfo(self: *NVML, index: u32) !GPUInfo {
        var info: GPUInfo = .{
            .index = index,
            .name = std.mem.zeroes([256:0]u8),
            .utilization = 0,
            .memory_used = 0,
            .memory_total = 0,
            .temperature = 0,
            .power_draw = 0,
            .clock_sm = 0,
            .clock_memory = 0,
            .uuid = std.mem.zeroes([64:0]u8),
            .vbios_version = std.mem.zeroes([32:0]u8),
        };

        var device: nvmlDevice_t = undefined;
        var result = self.get_handle_by_index(index, &device);
        if (result != NVML_SUCCESS) {
            self.setError("Failed to get device handle");
            return error.GetHandleFailed;
        }

        // Get name (return code deliberately ignored; buffer stays zeroed).
        if (self.get_name) |func| {
            _ = func(device, &info.name, @sizeOf(@TypeOf(info.name)));
        }

        // Get utilization
        if (self.get_utilization) |func| {
            var util: nvmlUtilization_t = undefined;
            result = func(device, &util);
            if (result == NVML_SUCCESS) {
                info.utilization = @intCast(util.gpu);
            }
        }

        // Get memory
        if (self.get_memory) |func| {
            var mem: nvmlMemory_t = undefined;
            result = func(device, &mem);
            if (result == NVML_SUCCESS) {
                info.memory_used = mem.used;
                info.memory_total = mem.total;
            }
        }

        // Get temperature
        if (self.get_temperature) |func| {
            var temp: c_uint = 0;
            result = func(device, NVML_TEMPERATURE_GPU, &temp);
            if (result == NVML_SUCCESS) {
                info.temperature = @intCast(temp);
            }
        }

        // Get power usage
        if (self.get_power_usage) |func| {
            var power: c_uint = 0;
            result = func(device, &power);
            if (result == NVML_SUCCESS) {
                info.power_draw = @intCast(power);
            }
        }

        // Get clocks (SM then memory, reusing one out-parameter).
        if (self.get_clock) |func| {
            var clock: c_uint = 0;
            result = func(device, NVML_CLOCK_SM, &clock);
            if (result == NVML_SUCCESS) {
                info.clock_sm = @intCast(clock);
            }
            result = func(device, NVML_CLOCK_MEM, &clock);
            if (result == NVML_SUCCESS) {
                info.clock_memory = @intCast(clock);
            }
        }

        // Get UUID
        if (self.get_uuid) |func| {
            _ = func(device, &info.uuid, @sizeOf(@TypeOf(info.uuid)));
        }

        // Get VBIOS version
        if (self.get_vbios) |func| {
            _ = func(device, &info.vbios_version, @sizeOf(@TypeOf(info.vbios_version)));
        }

        return info;
    }

    /// Get info for all GPUs. Caller owns the returned slice.
    /// NOTE(review): for count == 0 this returns a pointer to a static
    /// empty array, not an allocation — confirm callers free only
    /// non-empty results (freeing a zero-length slice may be benign, but
    /// the asymmetry is worth documenting at the call sites).
    pub fn getAllGPUInfo(self: *NVML, allocator: std.mem.Allocator) ![]GPUInfo {
        const count = try self.getGPUCount();
        if (count == 0) return &[_]GPUInfo{};

        var gpus = try allocator.alloc(GPUInfo, count);
        errdefer allocator.free(gpus);

        for (0..count) |i| {
            gpus[i] = try self.getGPUInfo(@intCast(i));
        }

        return gpus;
    }
};
|
||||
|
||||
// Convenience functions for simple use cases
|
||||
|
||||
/// Quick check if NVML is available (creates and destroys a temporary handle).
/// Any load failure — missing library, missing symbols, or init error —
/// collapses to false.
pub fn isNVMLAvailable() bool {
    const maybe_handle = NVML.load() catch return false;
    const handle = maybe_handle orelse return false;
    var nvml_mut = handle;
    defer nvml_mut.unload();
    return nvml_mut.isAvailable();
}
|
||||
|
||||
/// Format GPU info as string for display.
/// Renders one stanza per GPU; optional telemetry lines (power, SM clock)
/// are printed only when non-zero. Caller owns the returned buffer.
pub fn formatGPUInfo(allocator: std.mem.Allocator, gpus: []const GPUInfo) ![]u8 {
    var out = std.ArrayList(u8).init(allocator);
    defer out.deinit();
    const w = out.writer();

    try w.writeAll("GPU Status (NVML)\n");
    try w.writeAll("═" ** 50);
    try w.writeAll("\n\n");

    for (gpus) |gpu| {
        const display_name = std.mem.sliceTo(&gpu.name, 0);
        try w.print("🎮 GPU {d}: {s}\n", .{ gpu.index, display_name });
        try w.print(" Utilization: {d}%\n", .{gpu.utilization});
        try w.print(" Memory: {d}/{d} MB\n", .{
            gpu.memory_used / 1024 / 1024,
            gpu.memory_total / 1024 / 1024,
        });
        try w.print(" Temperature: {d}°C\n", .{gpu.temperature});
        if (gpu.power_draw > 0) {
            try w.print(" Power: {d:.1} W\n", .{@as(f64, @floatFromInt(gpu.power_draw)) / 1000.0});
        }
        if (gpu.clock_sm > 0) {
            try w.print(" SM Clock: {d} MHz\n", .{gpu.clock_sm});
        }
        try w.writeAll("\n");
    }

    return out.toOwnedSlice();
}
|
||||
|
|
@ -1,9 +1,8 @@
|
|||
const deps = @import("deps.zig");
|
||||
const std = deps.std;
|
||||
const crypto = deps.crypto;
|
||||
const io = deps.io;
|
||||
const log = deps.log;
|
||||
const protocol = deps.protocol;
|
||||
const std = @import("std");
|
||||
const crypto = @import("crypto");
|
||||
const io = @import("io");
|
||||
const log = @import("log");
|
||||
const protocol = @import("../protocol.zig");
|
||||
const resolve = @import("resolve.zig");
|
||||
const handshake = @import("handshake.zig");
|
||||
const frame = @import("frame.zig");
|
||||
|
|
@ -220,6 +219,36 @@ pub const Client = struct {
|
|||
try builder.send(stream);
|
||||
}
|
||||
|
||||
    /// Query a job run identified by (job_name, commit_id).
    /// Wire format:
    ///   [opcode:u8][api_key_hash:16][job_name_len:u8][job_name][commit_id:20]
    /// commit_id must be exactly 20 bytes — the copy below writes exactly
    /// 20; presumably validateCommitId enforces the length (confirm).
    pub fn sendQueryJobByCommit(self: *Client, job_name: []const u8, commit_id: []const u8, api_key_hash: []const u8) !void {
        const stream = try self.getStream();
        try validateApiKeyHash(api_key_hash);
        try validateCommitId(commit_id);
        try validateJobName(job_name);

        // Build binary message:
        // [opcode: u8] [api_key_hash: 16 bytes] [job_name_len: u8] [job_name: var] [commit_id: 20 bytes]
        const total_len = 1 + 16 + 1 + job_name.len + 20;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        var offset: usize = 0;
        buffer[offset] = @intFromEnum(opcode.query_job);
        offset += 1;

        @memcpy(buffer[offset .. offset + 16], api_key_hash);
        offset += 16;

        buffer[offset] = @intCast(job_name.len);
        offset += 1;

        @memcpy(buffer[offset .. offset + job_name.len], job_name);
        offset += job_name.len;

        @memcpy(buffer[offset .. offset + 20], commit_id);

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
pub fn sendListJupyterPackages(self: *Client, name: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
|
@ -272,6 +301,44 @@ pub const Client = struct {
|
|||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
    /// Send a privacy-settings patch (JSON) for a run.
    /// Wire format:
    ///   [opcode:u8][api_key_hash:16][job_name_len:u8][job_name]
    ///   [patch_len:u16 big-endian][patch_json]
    /// NOTE(review): this reads self.stream directly and validates inline,
    /// while sibling senders use getStream()/validate* helpers — confirm
    /// whether the divergence is intentional.
    pub fn sendSetRunPrivacy(
        self: *Client,
        job_name: []const u8,
        patch_json: []const u8,
        api_key_hash: []const u8,
    ) !void {
        const stream = self.stream orelse return error.NotConnected;

        if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
        if (job_name.len == 0 or job_name.len > 255) return error.JobNameTooLong;
        if (patch_json.len == 0 or patch_json.len > 0xFFFF) return error.PayloadTooLarge;

        // [opcode]
        // [api_key_hash:16]
        // [job_name_len:1][job_name]
        // [patch_len:2][patch_json]
        const total_len = 1 + 16 + 1 + job_name.len + 2 + patch_json.len;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        var offset: usize = 0;
        buffer[offset] = @intFromEnum(opcode.set_run_privacy);
        offset += 1;
        @memcpy(buffer[offset .. offset + 16], api_key_hash);
        offset += 16;

        buffer[offset] = @as(u8, @intCast(job_name.len));
        offset += 1;
        @memcpy(buffer[offset .. offset + job_name.len], job_name);
        offset += job_name.len;

        // Patch length prefix is big-endian, matching the other u16 fields.
        std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @as(u16, @intCast(patch_json.len)), .big);
        offset += 2;
        @memcpy(buffer[offset .. offset + patch_json.len], patch_json);

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
pub fn sendAnnotateRun(
|
||||
self: *Client,
|
||||
job_name: []const u8,
|
||||
|
|
@ -857,6 +924,62 @@ pub const Client = struct {
|
|||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
    /// Send a run-sync payload (JSON, may be empty).
    /// Wire format:
    ///   [opcode:u8][api_key_hash:16][json_len:u16 big-endian][json]
    pub fn sendSyncRun(self: *Client, sync_json: []const u8, api_key_hash: []const u8) !void {
        const stream = self.stream orelse return error.NotConnected;

        if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
        // u16 length prefix caps the payload at 65535 bytes.
        if (sync_json.len > 0xFFFF) return error.PayloadTooLarge;

        // Build binary message:
        // [opcode: u8] [api_key_hash: 16 bytes] [json_len: u16] [json: var]
        const total_len = 1 + 16 + 2 + sync_json.len;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        var offset: usize = 0;
        buffer[offset] = @intFromEnum(opcode.sync_run);
        offset += 1;

        @memcpy(buffer[offset .. offset + 16], api_key_hash);
        offset += 16;

        std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(sync_json.len), .big);
        offset += 2;

        // Zero-length payload is legal: only the length prefix is sent.
        if (sync_json.len > 0) {
            @memcpy(buffer[offset .. offset + sync_json.len], sync_json);
        }

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
    /// Request a rerun of a previous run by its id.
    /// Wire format:
    ///   [opcode:u8][api_key_hash:16][run_id_len:u8][run_id]
    /// NOTE(review): run_id.len == 0 is accepted here, unlike the name
    /// checks in sibling senders — confirm whether empty ids are valid.
    pub fn sendRerunRequest(self: *Client, run_id: []const u8, api_key_hash: []const u8) !void {
        const stream = self.stream orelse return error.NotConnected;

        if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
        // u8 length prefix caps run_id at 255 bytes.
        if (run_id.len > 255) return error.PayloadTooLarge;

        // Build binary message:
        // [opcode: u8] [api_key_hash: 16 bytes] [run_id_len: u8] [run_id: var]
        const total_len = 1 + 16 + 1 + run_id.len;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        var offset: usize = 0;
        buffer[offset] = @intFromEnum(opcode.rerun_request);
        offset += 1;

        @memcpy(buffer[offset .. offset + 16], api_key_hash);
        offset += 16;

        buffer[offset] = @intCast(run_id.len);
        offset += 1;

        @memcpy(buffer[offset .. offset + run_id.len], run_id);

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
pub fn sendStatusRequest(self: *Client, api_key_hash: []const u8) !void {
|
||||
const stream = try self.getStream();
|
||||
try validateApiKeyHash(api_key_hash);
|
||||
|
|
@ -1259,6 +1382,81 @@ pub const Client = struct {
|
|||
try frame.sendWebSocketFrame(stream, buffer);
|
||||
}
|
||||
|
||||
    /// Create a named experiment with an optional description.
    /// Wire format:
    ///   [opcode:u8][api_key_hash:16][name_len:u8][name]
    ///   [desc_len:u16 big-endian][description]
    pub fn sendCreateExperiment(self: *Client, api_key_hash: []const u8, name: []const u8, description: []const u8) !void {
        const stream = self.stream orelse return error.NotConnected;

        if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
        // Name is length-prefixed with a u8 and must be non-empty.
        if (name.len == 0 or name.len > 255) return error.NameTooLong;
        if (description.len > 1023) return error.DescriptionTooLong;

        // Build binary message:
        // [opcode: u8] [api_key_hash: 16 bytes] [name_len: u8] [name: var] [desc_len: u16] [description: var]
        const total_len = 1 + 16 + 1 + name.len + 2 + description.len;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        var offset: usize = 0;
        buffer[offset] = @intFromEnum(opcode.create_experiment);
        offset += 1;

        @memcpy(buffer[offset .. offset + 16], api_key_hash);
        offset += 16;

        buffer[offset] = @intCast(name.len);
        offset += 1;
        @memcpy(buffer[offset .. offset + name.len], name);
        offset += name.len;

        std.mem.writeInt(u16, buffer[offset .. offset + 2][0..2], @intCast(description.len), .big);
        offset += 2;
        // Description may be empty; only the prefix is written then.
        if (description.len > 0) {
            @memcpy(buffer[offset .. offset + description.len], description);
        }

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
    /// Request the list of experiments for the authenticated user.
    /// Wire format: [opcode:u8][api_key_hash:16] — no further payload.
    pub fn sendListExperiments(self: *Client, api_key_hash: []const u8) !void {
        const stream = self.stream orelse return error.NotConnected;

        if (api_key_hash.len != 16) return error.InvalidApiKeyHash;

        // Build binary message: [opcode: u8] [api_key_hash: 16 bytes]
        const total_len = 1 + 16;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        buffer[0] = @intFromEnum(opcode.list_experiments);
        @memcpy(buffer[1..17], api_key_hash);

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
    /// Fetch a single experiment by its id.
    /// Wire format:
    ///   [opcode:u8][api_key_hash:16][exp_id_len:u8][experiment_id]
    pub fn sendGetExperimentByID(self: *Client, api_key_hash: []const u8, experiment_id: []const u8) !void {
        const stream = self.stream orelse return error.NotConnected;

        if (api_key_hash.len != 16) return error.InvalidApiKeyHash;
        // Id is length-prefixed with a u8 and must be non-empty.
        if (experiment_id.len == 0 or experiment_id.len > 255) return error.InvalidExperimentId;

        // Build binary message: [opcode: u8] [api_key_hash: 16 bytes] [exp_id_len: u8] [experiment_id: var]
        const total_len = 1 + 16 + 1 + experiment_id.len;
        var buffer = try self.allocator.alloc(u8, total_len);
        defer self.allocator.free(buffer);

        var offset: usize = 0;
        buffer[offset] = @intFromEnum(opcode.get_experiment);
        offset += 1;

        @memcpy(buffer[offset .. offset + 16], api_key_hash);
        offset += 16;

        buffer[offset] = @intCast(experiment_id.len);
        offset += 1;
        @memcpy(buffer[offset .. offset + experiment_id.len], experiment_id);

        try frame.sendWebSocketFrame(stream, buffer);
    }
|
||||
|
||||
// Logs and debug methods
|
||||
pub fn sendGetLogs(self: *Client, target_id: []const u8, api_key_hash: []const u8) !void {
|
||||
const stream = self.stream orelse return error.NotConnected;
|
||||
|
|
|
|||
|
|
@ -6,12 +6,15 @@ pub const Opcode = enum(u8) {
|
|||
queue_job_with_note = 0x1B,
|
||||
annotate_run = 0x1C,
|
||||
set_run_narrative = 0x1D,
|
||||
set_run_privacy = 0x1F,
|
||||
status_request = 0x02,
|
||||
cancel_job = 0x03,
|
||||
prune = 0x04,
|
||||
crash_report = 0x05,
|
||||
log_metric = 0x0A,
|
||||
get_experiment = 0x0B,
|
||||
create_experiment = 0x24,
|
||||
list_experiments = 0x25,
|
||||
start_jupyter = 0x0D,
|
||||
stop_jupyter = 0x0E,
|
||||
remove_jupyter = 0x18,
|
||||
|
|
@ -21,6 +24,9 @@ pub const Opcode = enum(u8) {
|
|||
|
||||
validate_request = 0x16,
|
||||
|
||||
// Job query opcode
|
||||
query_job = 0x23,
|
||||
|
||||
// Logs and debug opcodes
|
||||
get_logs = 0x20,
|
||||
stream_logs = 0x21,
|
||||
|
|
@ -32,6 +38,12 @@ pub const Opcode = enum(u8) {
|
|||
dataset_info = 0x08,
|
||||
dataset_search = 0x09,
|
||||
|
||||
// Sync opcode
|
||||
sync_run = 0x26,
|
||||
|
||||
// Rerun opcode
|
||||
rerun_request = 0x27,
|
||||
|
||||
// Structured response opcodes
|
||||
response_success = 0x10,
|
||||
response_error = 0x11,
|
||||
|
|
@ -53,12 +65,15 @@ pub const queue_job_with_args = Opcode.queue_job_with_args;
|
|||
pub const queue_job_with_note = Opcode.queue_job_with_note;
|
||||
pub const annotate_run = Opcode.annotate_run;
|
||||
pub const set_run_narrative = Opcode.set_run_narrative;
|
||||
pub const set_run_privacy = Opcode.set_run_privacy;
|
||||
pub const status_request = Opcode.status_request;
|
||||
pub const cancel_job = Opcode.cancel_job;
|
||||
pub const prune = Opcode.prune;
|
||||
pub const crash_report = Opcode.crash_report;
|
||||
pub const log_metric = Opcode.log_metric;
|
||||
pub const get_experiment = Opcode.get_experiment;
|
||||
pub const create_experiment = Opcode.create_experiment;
|
||||
pub const list_experiments = Opcode.list_experiments;
|
||||
pub const start_jupyter = Opcode.start_jupyter;
|
||||
pub const stop_jupyter = Opcode.stop_jupyter;
|
||||
pub const remove_jupyter = Opcode.remove_jupyter;
|
||||
|
|
@ -66,6 +81,7 @@ pub const restore_jupyter = Opcode.restore_jupyter;
|
|||
pub const list_jupyter = Opcode.list_jupyter;
|
||||
pub const list_jupyter_packages = Opcode.list_jupyter_packages;
|
||||
pub const validate_request = Opcode.validate_request;
|
||||
pub const query_job = Opcode.query_job;
|
||||
pub const get_logs = Opcode.get_logs;
|
||||
pub const stream_logs = Opcode.stream_logs;
|
||||
pub const attach_debug = Opcode.attach_debug;
|
||||
|
|
@ -73,6 +89,8 @@ pub const dataset_list = Opcode.dataset_list;
|
|||
pub const dataset_register = Opcode.dataset_register;
|
||||
pub const dataset_info = Opcode.dataset_info;
|
||||
pub const dataset_search = Opcode.dataset_search;
|
||||
pub const sync_run = Opcode.sync_run;
|
||||
pub const rerun_request = Opcode.rerun_request;
|
||||
pub const response_success = Opcode.response_success;
|
||||
pub const response_error = Opcode.response_error;
|
||||
pub const response_progress = Opcode.response_progress;
|
||||
|
|
|
|||
|
|
@ -69,6 +69,9 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
colors.printInfo("Status retrieved for user: {s} (admin: {})\n", .{ name, admin });
|
||||
}
|
||||
|
||||
// Display system summary
|
||||
colors.printInfo("\n=== Queue Summary ===\n", .{});
|
||||
|
||||
// Display task summary
|
||||
if (root.get("tasks")) |tasks_obj| {
|
||||
const tasks = tasks_obj.object;
|
||||
|
|
@ -78,11 +81,18 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
const failed = tasks.get("failed").?.integer;
|
||||
const completed = tasks.get("completed").?.integer;
|
||||
colors.printInfo(
|
||||
"Tasks: {d} total | {d} queued | {d} running | {d} failed | {d} completed\n",
|
||||
"Total: {d} | Queued: {d} | Running: {d} | Failed: {d} | Completed: {d}\n",
|
||||
.{ total, queued, running, failed, completed },
|
||||
);
|
||||
}
|
||||
|
||||
// Display queue depth if available
|
||||
if (root.get("queue_length")) |ql| {
|
||||
if (ql == .integer) {
|
||||
colors.printInfo("Queue depth: {d}\n", .{ql.integer});
|
||||
}
|
||||
}
|
||||
|
||||
const per_section_limit: usize = options.limit orelse 5;
|
||||
|
||||
const TaskStatus = enum { queued, running, failed, completed };
|
||||
|
|
@ -120,43 +130,54 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
_ = allocator2;
|
||||
const label = statusLabel(status);
|
||||
const want = statusMatch(status);
|
||||
std.debug.print("\n{s}:\n", .{label});
|
||||
colors.printInfo("\n{s}:\n", .{label});
|
||||
|
||||
var shown: usize = 0;
|
||||
var position: usize = 0;
|
||||
for (queue_items) |item| {
|
||||
if (item != .object) continue;
|
||||
const obj = item.object;
|
||||
const st = utils.jsonGetString(obj, "status") orelse "";
|
||||
if (!std.mem.eql(u8, st, want)) continue;
|
||||
|
||||
position += 1;
|
||||
if (shown >= limit2) continue;
|
||||
|
||||
const id = utils.jsonGetString(obj, "id") orelse "";
|
||||
const job_name = utils.jsonGetString(obj, "job_name") orelse "";
|
||||
const worker_id = utils.jsonGetString(obj, "worker_id") orelse "";
|
||||
const err = utils.jsonGetString(obj, "error") orelse "";
|
||||
const priority = utils.jsonGetInt(obj, "priority") orelse 5;
|
||||
|
||||
// Show queue position for queued jobs
|
||||
const position_str = if (std.mem.eql(u8, want, "queued"))
|
||||
try std.fmt.allocPrint(std.heap.page_allocator, " [pos {d}]", .{position})
|
||||
else
|
||||
"";
|
||||
defer if (std.mem.eql(u8, want, "queued")) std.heap.page_allocator.free(position_str);
|
||||
|
||||
if (std.mem.eql(u8, want, "failed")) {
|
||||
colors.printWarning("- {s} {s}", .{ shorten(id, 8), job_name });
|
||||
colors.printWarning("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
if (worker_id.len > 0) {
|
||||
std.debug.print(" (worker={s})", .{worker_id});
|
||||
std.debug.print(" worker={s}", .{worker_id});
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
if (err.len > 0) {
|
||||
std.debug.print(" error: {s}\n", .{shorten(err, 160)});
|
||||
}
|
||||
} else if (std.mem.eql(u8, want, "running")) {
|
||||
colors.printInfo("- {s} {s}", .{ shorten(id, 8), job_name });
|
||||
colors.printInfo("- {s} {s}{s} (P:{d})", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
if (worker_id.len > 0) {
|
||||
std.debug.print(" (worker={s})", .{worker_id});
|
||||
std.debug.print(" worker={s}", .{worker_id});
|
||||
}
|
||||
std.debug.print("\n", .{});
|
||||
} else if (std.mem.eql(u8, want, "queued")) {
|
||||
std.debug.print("- {s} {s}\n", .{ shorten(id, 8), job_name });
|
||||
std.debug.print("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
} else {
|
||||
colors.printSuccess("- {s} {s}\n", .{ shorten(id, 8), job_name });
|
||||
colors.printSuccess("- {s} {s}{s} (P:{d})\n", .{ shorten(id, 8), job_name, position_str, priority });
|
||||
}
|
||||
|
||||
shown += 1;
|
||||
if (shown >= limit2) break;
|
||||
}
|
||||
|
||||
if (shown == 0) {
|
||||
|
|
@ -189,7 +210,7 @@ fn parseAndDisplayStatusJson(allocator: std.mem.Allocator, json_data: []const u8
|
|||
|
||||
if (try Client.formatPrewarmFromStatusRoot(allocator, root)) |section| {
|
||||
defer allocator.free(section);
|
||||
colors.printInfo("{s}", .{section});
|
||||
colors.printInfo("\n{s}", .{section});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
22
cli/src/server.zig
Normal file
22
cli/src/server.zig
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
//! Server mode operations module
|
||||
//!
|
||||
//! Provides implementations for CLI commands when running in server mode (WebSocket).
|
||||
//! These functions are called by the command routers in `src/commands/` when
|
||||
//! `Context.isServer()` returns true.
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```zig
|
||||
//! const server = @import("../server.zig");
|
||||
//!
|
||||
//! if (ctx.isServer()) {
|
||||
//! return try server.experiment.list(ctx.allocator, ctx.json_output);
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! ## Module Structure
|
||||
//!
|
||||
//! - `experiment_api.zig` - Experiment API operations via WebSocket
|
||||
//! - Future: `run_api.zig`, `metrics_api.zig`, etc.
|
||||
|
||||
pub const experiment = @import("server/experiment_api.zig");
|
||||
124
cli/src/server/experiment_api.zig
Normal file
124
cli/src/server/experiment_api.zig
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
const std = @import("std");
|
||||
const ws = @import("../net/ws/client.zig");
|
||||
const protocol = @import("../net/protocol.zig");
|
||||
const config = @import("../config.zig");
|
||||
const crypto = @import("../utils/crypto.zig");
|
||||
const core = @import("../core.zig");
|
||||
|
||||
/// Log a metric to server mode
|
||||
pub fn logMetric(
|
||||
allocator: std.mem.Allocator,
|
||||
commit_id: []const u8,
|
||||
name: []const u8,
|
||||
value: f64,
|
||||
step: u32,
|
||||
json: bool,
|
||||
) !void {
|
||||
const cfg = try config.Config.load(allocator);
|
||||
defer {
|
||||
var mut_cfg = cfg;
|
||||
mut_cfg.deinit(allocator);
|
||||
}
|
||||
|
||||
const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key);
|
||||
defer allocator.free(api_key_hash);
|
||||
|
||||
const ws_url = try cfg.getWebSocketUrl(allocator);
|
||||
defer allocator.free(ws_url);
|
||||
|
||||
var client = try ws.Client.connect(allocator, ws_url, cfg.api_key);
|
||||
defer client.close();
|
||||
|
||||
try client.sendLogMetric(api_key_hash, commit_id, name, value, step);
|
||||
|
||||
if (json) {
|
||||
const message = try client.receiveMessage(allocator);
|
||||
defer allocator.free(message);
|
||||
|
||||
const packet = protocol.ResponsePacket.deserialize(message, allocator) catch {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"commit_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}},\"message\":\"{s}\"}}}}\n",
|
||||
.{ commit_id, name, value, step, message },
|
||||
);
|
||||
return;
|
||||
};
|
||||
defer packet.deinit(allocator);
|
||||
|
||||
switch (packet.packet_type) {
|
||||
.success => {
|
||||
std.debug.print(
|
||||
"{{\"success\":true,\"command\":\"experiment.log\",\"data\":{{\"commit_id\":\"{s}\",\"metric\":{{\"name\":\"{s}\",\"value\":{d},\"step\":{d}}},\"message\":\"{s}\"}}}}\n",
|
||||
.{ commit_id, name, value, step, message },
|
||||
);
|
||||
return;
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
} else {
|
||||
try client.receiveAndHandleResponse(allocator, "Log metric");
|
||||
std.debug.print("Metric logged successfully!\n", .{});
|
||||
std.debug.print("Commit ID: {s}\n", .{commit_id});
|
||||
std.debug.print("Metric: {s} = {d:.4} (step {d})\n", .{ name, value, step });
|
||||
}
|
||||
}
|
||||
|
||||
/// List experiments from server mode
|
||||
pub fn list(allocator: std.mem.Allocator, json: bool) !void {
|
||||
const entries = @import("../utils/history.zig").loadEntries(allocator) catch |err| {
|
||||
if (json) {
|
||||
const details = try std.fmt.allocPrint(allocator, "{}", .{err});
|
||||
defer allocator.free(details);
|
||||
core.output.errorMsgDetailed("experiment.list", "Failed to read experiment history", details);
|
||||
} else {
|
||||
std.log.err("Failed to read experiment history: {}", .{err});
|
||||
}
|
||||
return err;
|
||||
};
|
||||
defer @import("../utils/history.zig").freeEntries(allocator, entries);
|
||||
|
||||
if (entries.len == 0) {
|
||||
if (json) {
|
||||
std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[],\"total\":0,\"message\":\"No experiments recorded yet. Use `ml queue` to submit one.\"}}}}\n", .{});
|
||||
} else {
|
||||
std.debug.print("No experiments recorded yet. Use `ml queue` to submit one.\n", .{});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (json) {
|
||||
std.debug.print("{{\"success\":true,\"command\":\"experiment.list\",\"data\":{{\"experiments\":[", .{});
|
||||
var idx: usize = 0;
|
||||
while (idx < entries.len) : (idx += 1) {
|
||||
const entry = entries[entries.len - idx - 1];
|
||||
if (idx > 0) {
|
||||
std.debug.print(",", .{});
|
||||
}
|
||||
std.debug.print(
|
||||
"{{\"alias\":\"{s}\",\"commit_id\":\"{s}\",\"queued_at\":{d}}}",
|
||||
.{
|
||||
entry.job_name,
|
||||
entry.commit_id,
|
||||
entry.queued_at,
|
||||
},
|
||||
);
|
||||
}
|
||||
std.debug.print("],\"total\":{d}", .{entries.len});
|
||||
std.debug.print("}}}}\n", .{});
|
||||
} else {
|
||||
std.debug.print("\nRecent Experiments (latest first):\n", .{});
|
||||
std.debug.print("---------------------------------\n", .{});
|
||||
|
||||
const max_display = if (entries.len > 20) 20 else entries.len;
|
||||
var idx: usize = 0;
|
||||
while (idx < max_display) : (idx += 1) {
|
||||
const entry = entries[entries.len - idx - 1];
|
||||
std.debug.print("{d:2}) Alias: {s}\n", .{ idx + 1, entry.job_name });
|
||||
std.debug.print(" Commit: {s}\n", .{entry.commit_id});
|
||||
std.debug.print(" Queued: {d}\n\n", .{entry.queued_at});
|
||||
}
|
||||
|
||||
if (entries.len > max_display) {
|
||||
std.debug.print("...and {d} more\n", .{entries.len - max_display});
|
||||
}
|
||||
}
|
||||
}
|
||||
151
cli/src/ui/progress.zig
Normal file
151
cli/src/ui/progress.zig
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
const std = @import("std");
|
||||
const colors = @import("../utils/colors.zig");
|
||||
|
||||
/// ProgressBar provides visual feedback for long-running operations.
|
||||
/// It displays progress as a percentage, item count, and throughput rate.
|
||||
pub const ProgressBar = struct {
|
||||
total: usize,
|
||||
current: usize,
|
||||
label: []const u8,
|
||||
start_time: i64,
|
||||
width: usize,
|
||||
|
||||
/// Initialize a new progress bar
|
||||
pub fn init(total: usize, label: []const u8) ProgressBar {
|
||||
return .{
|
||||
.total = total,
|
||||
.current = 0,
|
||||
.label = label,
|
||||
.start_time = std.time.milliTimestamp(),
|
||||
.width = 40, // Default bar width
|
||||
};
|
||||
}
|
||||
|
||||
/// Update the progress bar with current progress
|
||||
pub fn update(self: *ProgressBar, current: usize) void {
|
||||
self.current = current;
|
||||
self.render();
|
||||
}
|
||||
|
||||
/// Increment progress by one step
|
||||
pub fn increment(self: *ProgressBar) void {
|
||||
self.current += 1;
|
||||
self.render();
|
||||
}
|
||||
|
||||
/// Render the progress bar to stderr
|
||||
fn render(self: ProgressBar) void {
|
||||
const percent = if (self.total > 0)
|
||||
@divFloor(self.current * 100, self.total)
|
||||
else
|
||||
0;
|
||||
|
||||
const elapsed_ms = std.time.milliTimestamp() - self.start_time;
|
||||
const rate = if (elapsed_ms > 0 and self.current > 0)
|
||||
@as(f64, @floatFromInt(self.current)) / (@as(f64, @floatFromInt(elapsed_ms)) / 1000.0)
|
||||
else
|
||||
0.0;
|
||||
|
||||
// Build progress bar
|
||||
const filled = if (self.total > 0)
|
||||
@divFloor(self.current * self.width, self.total)
|
||||
else
|
||||
0;
|
||||
const empty = self.width - filled;
|
||||
|
||||
var bar_buf: [64]u8 = undefined;
|
||||
var bar_stream = std.io.fixedBufferStream(&bar_buf);
|
||||
const bar_writer = bar_stream.writer();
|
||||
|
||||
// Write filled portion
|
||||
var i: usize = 0;
|
||||
while (i < filled) : (i += 1) {
|
||||
_ = bar_writer.write("=") catch {};
|
||||
}
|
||||
// Write empty portion
|
||||
i = 0;
|
||||
while (i < empty) : (i += 1) {
|
||||
_ = bar_writer.write("-") catch {};
|
||||
}
|
||||
|
||||
const bar = bar_stream.getWritten();
|
||||
|
||||
// Clear line and print progress
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
stderr.print("\r{s} [{s}] {d}/{d} {d}% ({d:.1} items/s)", .{
|
||||
self.label,
|
||||
bar,
|
||||
self.current,
|
||||
self.total,
|
||||
percent,
|
||||
rate,
|
||||
}) catch {};
|
||||
}
|
||||
|
||||
/// Finish the progress bar and print a newline
|
||||
pub fn finish(self: ProgressBar) void {
|
||||
self.render();
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
stderr.print("\n", .{}) catch {};
|
||||
}
|
||||
|
||||
/// Complete with a success message
|
||||
pub fn success(self: ProgressBar, msg: []const u8) void {
|
||||
self.current = self.total;
|
||||
self.render();
|
||||
colors.printSuccess("\n{s}\n", .{msg});
|
||||
}
|
||||
|
||||
/// Get elapsed time in milliseconds
|
||||
pub fn elapsedMs(self: ProgressBar) i64 {
|
||||
return std.time.milliTimestamp() - self.start_time;
|
||||
}
|
||||
|
||||
/// Get current throughput (items per second)
|
||||
pub fn throughput(self: ProgressBar) f64 {
|
||||
const elapsed_ms = self.elapsedMs();
|
||||
if (elapsed_ms > 0 and self.current > 0) {
|
||||
return @as(f64, @floatFromInt(self.current)) / (@as(f64, @floatFromInt(elapsed_ms)) / 1000.0);
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Spinner provides visual feedback for indeterminate operations
|
||||
pub const Spinner = struct {
|
||||
label: []const u8,
|
||||
start_time: i64,
|
||||
frames: []const u8,
|
||||
frame_idx: usize,
|
||||
|
||||
const DEFAULT_FRAMES = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏";
|
||||
|
||||
pub fn init(label: []const u8) Spinner {
|
||||
return .{
|
||||
.label = label,
|
||||
.start_time = std.time.milliTimestamp(),
|
||||
.frames = DEFAULT_FRAMES,
|
||||
.frame_idx = 0,
|
||||
};
|
||||
}
|
||||
|
||||
/// Render one frame of the spinner
|
||||
pub fn tick(self: *Spinner) void {
|
||||
const frame = self.frames[self.frame_idx % self.frames.len];
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
stderr.print("\r{s} {c} ", .{ self.label, frame }) catch {};
|
||||
self.frame_idx += 1;
|
||||
}
|
||||
|
||||
/// Stop the spinner and print a newline
|
||||
pub fn stop(self: Spinner) void {
|
||||
_ = self; // Intentionally unused - for API consistency
|
||||
const stderr = std.io.getStdErr().writer();
|
||||
stderr.print("\n", .{}) catch {};
|
||||
}
|
||||
|
||||
/// Get elapsed time in seconds
|
||||
pub fn elapsedSec(self: Spinner) i64 {
|
||||
return @divFloor(std.time.milliTimestamp() - self.start_time, 1000);
|
||||
}
|
||||
};
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
const std = @import("std");
|
||||
const ignore = @import("ignore.zig");
|
||||
|
||||
pub fn encodeHexLower(allocator: std.mem.Allocator, bytes: []const u8) ![]u8 {
|
||||
const hex = try allocator.alloc(u8, bytes.len * 2);
|
||||
|
|
@ -46,15 +47,81 @@ pub fn hashApiKey(allocator: std.mem.Allocator, api_key: []const u8) ![]u8 {
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Calculate commit ID for a directory (SHA256 of tree state)
|
||||
/// Calculate SHA256 hash of a file
|
||||
pub fn hashFile(allocator: std.mem.Allocator, file_path: []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
const file = try std.fs.cwd().openFile(file_path, .{});
|
||||
defer file.close();
|
||||
|
||||
var buf: [4096]u8 = undefined;
|
||||
while (true) {
|
||||
const bytes_read = try file.read(&buf);
|
||||
if (bytes_read == 0) break;
|
||||
hasher.update(buf[0..bytes_read]);
|
||||
}
|
||||
|
||||
var hash: [32]u8 = undefined;
|
||||
hasher.final(&hash);
|
||||
return encodeHexLower(allocator, &hash);
|
||||
}
|
||||
|
||||
/// Calculate combined hash of multiple files (sorted by path)
|
||||
pub fn hashFiles(allocator: std.mem.Allocator, dir_path: []const u8, file_paths: []const []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
// Copy and sort paths for deterministic hashing
|
||||
var sorted_paths = std.ArrayList([]const u8).initCapacity(allocator, file_paths.len) catch |err| {
|
||||
return err;
|
||||
};
|
||||
defer sorted_paths.deinit(allocator);
|
||||
|
||||
for (file_paths) |path| {
|
||||
try sorted_paths.append(allocator, path);
|
||||
}
|
||||
|
||||
std.sort.block([]const u8, sorted_paths.items, {}, struct {
|
||||
fn lessThan(_: void, a: []const u8, b: []const u8) bool {
|
||||
return std.mem.order(u8, a, b) == .lt;
|
||||
}
|
||||
}.lessThan);
|
||||
|
||||
// Hash each file
|
||||
for (sorted_paths.items) |path| {
|
||||
hasher.update(path);
|
||||
hasher.update(&[_]u8{0}); // Separator
|
||||
|
||||
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, path });
|
||||
defer allocator.free(full_path);
|
||||
|
||||
const file_hash = try hashFile(allocator, full_path);
|
||||
defer allocator.free(file_hash);
|
||||
|
||||
hasher.update(file_hash);
|
||||
hasher.update(&[_]u8{0}); // Separator
|
||||
}
|
||||
|
||||
var hash: [32]u8 = undefined;
|
||||
hasher.final(&hash);
|
||||
return encodeHexLower(allocator, &hash);
|
||||
}
|
||||
|
||||
/// Calculate commit ID for a directory (SHA256 of tree state, respecting .gitignore)
|
||||
pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
|
||||
defer dir.close();
|
||||
|
||||
// Load .gitignore and .mlignore patterns
|
||||
var gitignore = ignore.GitIgnore.init(allocator);
|
||||
defer gitignore.deinit();
|
||||
|
||||
try gitignore.loadFromDir(dir_path, ".gitignore");
|
||||
try gitignore.loadFromDir(dir_path, ".mlignore");
|
||||
|
||||
var walker = try dir.walk(allocator);
|
||||
defer walker.deinit();
|
||||
defer walker.deinit(allocator);
|
||||
|
||||
// Collect and sort paths for deterministic hashing
|
||||
var paths: std.ArrayList([]const u8) = .{};
|
||||
|
|
@ -65,6 +132,12 @@ pub fn hashDirectory(allocator: std.mem.Allocator, dir_path: []const u8) ![]u8 {
|
|||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind == .file) {
|
||||
// Skip files matching default ignores
|
||||
if (ignore.matchesDefaultIgnore(entry.path)) continue;
|
||||
|
||||
// Skip files matching .gitignore/.mlignore patterns
|
||||
if (gitignore.isIgnored(entry.path, false)) continue;
|
||||
|
||||
try paths.append(allocator, try allocator.dupe(u8, entry.path));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
333
cli/src/utils/hash_cache.zig
Normal file
333
cli/src/utils/hash_cache.zig
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
const std = @import("std");
|
||||
const crypto = @import("crypto.zig");
|
||||
const json = @import("json.zig");
|
||||
|
||||
/// Cache entry for a single file
|
||||
const CacheEntry = struct {
|
||||
mtime: i64,
|
||||
hash: []const u8,
|
||||
|
||||
pub fn deinit(self: *const CacheEntry, allocator: std.mem.Allocator) void {
|
||||
allocator.free(self.hash);
|
||||
}
|
||||
};
|
||||
|
||||
/// Hash cache that stores file mtimes and hashes to avoid re-hashing unchanged files
|
||||
pub const HashCache = struct {
|
||||
entries: std.StringHashMap(CacheEntry),
|
||||
allocator: std.mem.Allocator,
|
||||
cache_path: []const u8,
|
||||
dirty: bool,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator) HashCache {
|
||||
return .{
|
||||
.entries = std.StringHashMap(CacheEntry).init(allocator),
|
||||
.allocator = allocator,
|
||||
.cache_path = "",
|
||||
.dirty = false,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *HashCache) void {
|
||||
var it = self.entries.iterator();
|
||||
while (it.next()) |entry| {
|
||||
entry.value_ptr.deinit(self.allocator);
|
||||
self.allocator.free(entry.key_ptr.*);
|
||||
}
|
||||
self.entries.deinit();
|
||||
if (self.cache_path.len > 0) {
|
||||
self.allocator.free(self.cache_path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get default cache path: ~/.ml/cache/hashes.json
|
||||
pub fn getDefaultPath(allocator: std.mem.Allocator) ![]const u8 {
|
||||
const home = std.posix.getenv("HOME") orelse {
|
||||
return error.NoHomeDirectory;
|
||||
};
|
||||
|
||||
// Ensure cache directory exists
|
||||
const cache_dir = try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache" });
|
||||
defer allocator.free(cache_dir);
|
||||
|
||||
std.fs.cwd().makeDir(cache_dir) catch |err| switch (err) {
|
||||
error.PathAlreadyExists => {},
|
||||
else => return err,
|
||||
};
|
||||
|
||||
return try std.fs.path.join(allocator, &[_][]const u8{ home, ".ml", "cache", "hashes.json" });
|
||||
}
|
||||
|
||||
/// Load cache from disk
|
||||
pub fn load(self: *HashCache) !void {
|
||||
const cache_path = try getDefaultPath(self.allocator);
|
||||
self.cache_path = cache_path;
|
||||
|
||||
const file = std.fs.cwd().openFile(cache_path, .{}) catch |err| switch (err) {
|
||||
error.FileNotFound => return, // No cache yet is fine
|
||||
else => return err,
|
||||
};
|
||||
defer file.close();
|
||||
|
||||
const content = try file.readToEndAlloc(self.allocator, 10 * 1024 * 1024); // Max 10MB
|
||||
defer self.allocator.free(content);
|
||||
|
||||
// Parse JSON
|
||||
const parsed = try std.json.parseFromSlice(std.json.Value, self.allocator, content, .{});
|
||||
defer parsed.deinit();
|
||||
|
||||
const root = parsed.value.object;
|
||||
const version = root.get("version") orelse return error.InvalidCacheFormat;
|
||||
if (version.integer != 1) return error.UnsupportedCacheVersion;
|
||||
|
||||
const files = root.get("files") orelse return error.InvalidCacheFormat;
|
||||
if (files.object.count() == 0) return;
|
||||
|
||||
var it = files.object.iterator();
|
||||
while (it.next()) |entry| {
|
||||
const path = try self.allocator.dupe(u8, entry.key_ptr.*);
|
||||
|
||||
const file_obj = entry.value_ptr.object;
|
||||
const mtime = file_obj.get("mtime") orelse continue;
|
||||
const hash_val = file_obj.get("hash") orelse continue;
|
||||
|
||||
const hash = try self.allocator.dupe(u8, hash_val.string);
|
||||
|
||||
try self.entries.put(path, .{
|
||||
.mtime = mtime.integer,
|
||||
.hash = hash,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Save cache to disk
|
||||
pub fn save(self: *HashCache) !void {
|
||||
if (!self.dirty) return;
|
||||
|
||||
var json_str = std.ArrayList(u8).init(self.allocator);
|
||||
defer json_str.deinit();
|
||||
|
||||
var writer = json_str.writer();
|
||||
|
||||
// Write header
|
||||
try writer.print("{{\n \"version\": 1,\n \"files\": {{\n", .{});
|
||||
|
||||
// Write entries
|
||||
var it = self.entries.iterator();
|
||||
var first = true;
|
||||
while (it.next()) |entry| {
|
||||
if (!first) try writer.print(",\n", .{});
|
||||
first = false;
|
||||
|
||||
// Escape path for JSON
|
||||
const escaped_path = try json.escapeString(self.allocator, entry.key_ptr.*);
|
||||
defer self.allocator.free(escaped_path);
|
||||
|
||||
try writer.print(" \"{s}\": {{\"mtime\": {d}, \"hash\": \"{s}\"}}", .{
|
||||
escaped_path,
|
||||
entry.value_ptr.mtime,
|
||||
entry.value_ptr.hash,
|
||||
});
|
||||
}
|
||||
|
||||
// Write footer
|
||||
try writer.print("\n }}\n}}\n", .{});
|
||||
|
||||
// Write atomically
|
||||
const tmp_path = try std.fmt.allocPrint(self.allocator, "{s}.tmp", .{self.cache_path});
|
||||
defer self.allocator.free(tmp_path);
|
||||
|
||||
{
|
||||
const file = try std.fs.cwd().createFile(tmp_path, .{});
|
||||
defer file.close();
|
||||
try file.writeAll(json_str.items);
|
||||
}
|
||||
|
||||
try std.fs.cwd().rename(tmp_path, self.cache_path);
|
||||
self.dirty = false;
|
||||
}
|
||||
|
||||
/// Check if file needs re-hashing
|
||||
pub fn needsHash(self: *HashCache, path: []const u8, mtime: i64) bool {
|
||||
const entry = self.entries.get(path) orelse return true;
|
||||
return entry.mtime != mtime;
|
||||
}
|
||||
|
||||
/// Get cached hash for file
|
||||
pub fn getHash(self: *HashCache, path: []const u8, mtime: i64) ?[]const u8 {
|
||||
const entry = self.entries.get(path) orelse return null;
|
||||
if (entry.mtime != mtime) return null;
|
||||
return entry.hash;
|
||||
}
|
||||
|
||||
/// Store hash for file
|
||||
pub fn putHash(self: *HashCache, path: []const u8, mtime: i64, hash: []const u8) !void {
|
||||
const path_copy = try self.allocator.dupe(u8, path);
|
||||
|
||||
// Remove old entry if exists
|
||||
if (self.entries.fetchRemove(path_copy)) |old| {
|
||||
self.allocator.free(old.key);
|
||||
old.value.deinit(self.allocator);
|
||||
}
|
||||
|
||||
const hash_copy = try self.allocator.dupe(u8, hash);
|
||||
|
||||
try self.entries.put(path_copy, .{
|
||||
.mtime = mtime,
|
||||
.hash = hash_copy,
|
||||
});
|
||||
|
||||
self.dirty = true;
|
||||
}
|
||||
|
||||
/// Clear cache (e.g., after git checkout)
|
||||
pub fn clear(self: *HashCache) void {
|
||||
var it = self.entries.iterator();
|
||||
while (it.next()) |entry| {
|
||||
entry.value_ptr.deinit(self.allocator);
|
||||
self.allocator.free(entry.key_ptr.*);
|
||||
}
|
||||
self.entries.clearRetainingCapacity();
|
||||
self.dirty = true;
|
||||
}
|
||||
|
||||
/// Get cache stats
|
||||
pub fn getStats(self: *HashCache) struct { entries: usize, dirty: bool } {
|
||||
return .{
|
||||
.entries = self.entries.count(),
|
||||
.dirty = self.dirty,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/// Calculate directory hash with cache support
|
||||
pub fn hashDirectoryWithCache(
|
||||
allocator: std.mem.Allocator,
|
||||
dir_path: []const u8,
|
||||
cache: *HashCache,
|
||||
) ![]const u8 {
|
||||
var hasher = std.crypto.hash.sha2.Sha256.init(.{});
|
||||
|
||||
var dir = try std.fs.cwd().openDir(dir_path, .{ .iterate = true });
|
||||
defer dir.close();
|
||||
|
||||
// Load .gitignore patterns
|
||||
var gitignore = @import("ignore.zig").GitIgnore.init(allocator);
|
||||
defer gitignore.deinit();
|
||||
|
||||
try gitignore.loadFromDir(dir_path, ".gitignore");
|
||||
try gitignore.loadFromDir(dir_path, ".mlignore");
|
||||
|
||||
var walker = try dir.walk(allocator);
|
||||
defer walker.deinit(allocator);
|
||||
|
||||
// Collect paths and check cache
|
||||
var paths: std.ArrayList(struct { path: []const u8, mtime: i64, use_cache: bool }) = .{};
|
||||
defer {
|
||||
for (paths.items) |p| allocator.free(p.path);
|
||||
paths.deinit(allocator);
|
||||
}
|
||||
|
||||
while (try walker.next()) |entry| {
|
||||
if (entry.kind == .file) {
|
||||
// Skip files matching default ignores
|
||||
if (@import("ignore.zig").matchesDefaultIgnore(entry.path)) continue;
|
||||
|
||||
// Skip files matching .gitignore/.mlignore patterns
|
||||
if (gitignore.isIgnored(entry.path, false)) continue;
|
||||
|
||||
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, entry.path });
|
||||
defer allocator.free(full_path);
|
||||
|
||||
const stat = dir.statFile(entry.path) catch |err| switch (err) {
|
||||
error.FileNotFound => continue,
|
||||
else => return err,
|
||||
};
|
||||
|
||||
const mtime = @as(i64, @intCast(stat.mtime));
|
||||
const use_cache = !cache.needsHash(entry.path, mtime);
|
||||
|
||||
try paths.append(.{
|
||||
.path = try allocator.dupe(u8, entry.path),
|
||||
.mtime = mtime,
|
||||
.use_cache = use_cache,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort paths for deterministic hashing
|
||||
std.sort.block(
|
||||
struct { path: []const u8, mtime: i64, use_cache: bool },
|
||||
paths.items,
|
||||
{},
|
||||
struct {
|
||||
fn lessThan(_: void, a: anytype, b: anytype) bool {
|
||||
return std.mem.order(u8, a.path, b.path) == .lt;
|
||||
}
|
||||
}.lessThan,
|
||||
);
|
||||
|
||||
// Hash each file (using cache where possible)
|
||||
for (paths.items) |item| {
|
||||
hasher.update(item.path);
|
||||
hasher.update(&[_]u8{0}); // Separator
|
||||
|
||||
const file_hash: []const u8 = if (item.use_cache)
|
||||
cache.getHash(item.path, item.mtime).?
|
||||
else blk: {
|
||||
const full_path = try std.fs.path.join(allocator, &[_][]const u8{ dir_path, item.path });
|
||||
defer allocator.free(full_path);
|
||||
|
||||
const hash = try crypto.hashFile(allocator, full_path);
|
||||
try cache.putHash(item.path, item.mtime, hash);
|
||||
break :blk hash;
|
||||
};
|
||||
defer if (!item.use_cache) allocator.free(file_hash);
|
||||
|
||||
hasher.update(file_hash);
|
||||
hasher.update(&[_]u8{0}); // Separator
|
||||
}
|
||||
|
||||
var hash: [32]u8 = undefined;
|
||||
hasher.final(&hash);
|
||||
return crypto.encodeHexLower(allocator, &hash);
|
||||
}
|
||||
|
||||
test "HashCache basic operations" {
|
||||
const allocator = std.testing.allocator;
|
||||
|
||||
var cache = HashCache.init(allocator);
|
||||
defer cache.deinit();
|
||||
|
||||
// Put and get
|
||||
try cache.putHash("src/main.py", 1708369200, "abc123");
|
||||
|
||||
const hash = cache.getHash("src/main.py", 1708369200);
|
||||
try std.testing.expect(hash != null);
|
||||
try std.testing.expectEqualStrings("abc123", hash.?);
|
||||
|
||||
// Wrong mtime should return null
|
||||
const stale = cache.getHash("src/main.py", 1708369201);
|
||||
try std.testing.expect(stale == null);
|
||||
|
||||
// needsHash should detect stale entries
|
||||
try std.testing.expect(cache.needsHash("src/main.py", 1708369201));
|
||||
try std.testing.expect(!cache.needsHash("src/main.py", 1708369200));
|
||||
}
|
||||
|
||||
test "HashCache clear" {
|
||||
const allocator = std.testing.allocator;
|
||||
|
||||
var cache = HashCache.init(allocator);
|
||||
defer cache.deinit();
|
||||
|
||||
try cache.putHash("file1.py", 123, "hash1");
|
||||
try cache.putHash("file2.py", 456, "hash2");
|
||||
|
||||
try std.testing.expectEqual(@as(usize, 2), cache.getStats().entries);
|
||||
|
||||
cache.clear();
|
||||
|
||||
try std.testing.expectEqual(@as(usize, 0), cache.getStats().entries);
|
||||
try std.testing.expect(cache.getStats().dirty);
|
||||
}
|
||||
261
cli/src/utils/ignore.zig
Normal file
261
cli/src/utils/ignore.zig
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Pattern type for ignore rules
|
||||
const Pattern = struct {
|
||||
pattern: []const u8,
|
||||
is_negation: bool, // true if pattern starts with !
|
||||
is_dir_only: bool, // true if pattern ends with /
|
||||
anchored: bool, // true if pattern contains / (not at start)
|
||||
};
|
||||
|
||||
/// GitIgnore matcher for filtering files during directory traversal
///
/// Supports a simplified gitignore syntax:
///   - blank lines and `# comment` lines are skipped (in parse)
///   - `!pattern` re-includes a previously ignored path
///   - `pattern/` matches directories only
///   - a pattern containing `/` is treated as anchored
///   - `*`, `**` and `?` wildcards (see patternMatch)
pub const GitIgnore = struct {
    patterns: std.ArrayList(Pattern),
    allocator: std.mem.Allocator,

    pub fn init(allocator: std.mem.Allocator) GitIgnore {
        return .{
            .patterns = std.ArrayList(Pattern).init(allocator),
            .allocator = allocator,
        };
    }

    pub fn deinit(self: *GitIgnore) void {
        for (self.patterns.items) |p| {
            self.allocator.free(p.pattern);
        }
        self.patterns.deinit();
    }

    /// Load .gitignore or .mlignore from directory.
    /// A missing file is silently treated as "no patterns".
    pub fn loadFromDir(self: *GitIgnore, dir_path: []const u8, filename: []const u8) !void {
        const path = try std.fs.path.join(self.allocator, &[_][]const u8{ dir_path, filename });
        defer self.allocator.free(path);

        const file = std.fs.cwd().openFile(path, .{}) catch |err| switch (err) {
            error.FileNotFound => return, // No ignore file is fine
            else => return err,
        };
        defer file.close();

        const content = try file.readToEndAlloc(self.allocator, 1024 * 1024); // Max 1MB
        defer self.allocator.free(content);

        try self.parse(content);
    }

    /// Parse ignore patterns from content (one pattern per line).
    pub fn parse(self: *GitIgnore, content: []const u8) !void {
        var lines = std.mem.split(u8, content, "\n");
        while (lines.next()) |line| {
            const trimmed = std.mem.trim(u8, line, " \t\r");

            // Skip empty lines and comments
            if (trimmed.len == 0 or std.mem.startsWith(u8, trimmed, "#")) continue;

            try self.addPattern(trimmed);
        }
    }

    /// Add a single pattern, normalizing `!`, trailing `/` and leading `/`.
    fn addPattern(self: *GitIgnore, pattern: []const u8) !void {
        var p = pattern;
        var is_negation = false;
        var is_dir_only = false;

        // Check for negation
        if (std.mem.startsWith(u8, p, "!")) {
            is_negation = true;
            p = p[1..];
        }

        // Check for directory-only marker
        if (std.mem.endsWith(u8, p, "/")) {
            is_dir_only = true;
            p = p[0 .. p.len - 1];
        }

        // Anchored if the pattern still contains a slash anywhere.
        const anchored = std.mem.indexOf(u8, p, "/") != null;
        if (std.mem.startsWith(u8, p, "/")) {
            p = p[1..];
        }

        // Store normalized pattern
        const pattern_copy = try self.allocator.dupe(u8, p);
        try self.patterns.append(.{
            .pattern = pattern_copy,
            .is_negation = is_negation,
            .is_dir_only = is_dir_only,
            .anchored = anchored,
        });
    }

    /// Check if a path should be ignored.
    /// Later patterns override earlier ones, so negations behave like git's.
    pub fn isIgnored(self: *GitIgnore, path: []const u8, is_dir: bool) bool {
        var ignored = false;

        for (self.patterns.items) |pattern| {
            if (self.matches(pattern, path, is_dir)) {
                ignored = !pattern.is_negation;
            }
        }

        return ignored;
    }

    /// Check if a single pattern matches the full path or its basename.
    fn matches(self: *GitIgnore, pattern: Pattern, path: []const u8, is_dir: bool) bool {
        _ = self;

        // Directory-only patterns only match directories
        if (pattern.is_dir_only and !is_dir) return false;

        if (patternMatch(pattern.pattern, path)) {
            return true;
        }

        // Also check basename match for non-anchored patterns
        if (!pattern.anchored) {
            if (std.mem.lastIndexOf(u8, path, "/")) |idx| {
                const basename = path[idx + 1 ..];
                if (patternMatch(pattern.pattern, basename)) {
                    return true;
                }
            }
        }

        return false;
    }

    /// Glob matching with backtracking.
    ///
    /// `*` matches any run of characters (including none), `?` matches
    /// exactly one character, and a run of stars (`**`) is collapsed into a
    /// single `*`. The previous matcher was greedy without backtracking
    /// ("*.tar.gz" failed on "foo.x.tar.gz" because `*` stopped at the
    /// first '.'), and any pattern containing "**" unconditionally matched
    /// every path regardless of the rest of the pattern.
    fn patternMatch(pattern: []const u8, path: []const u8) bool {
        var p_idx: usize = 0;
        var s_idx: usize = 0;
        // Backtracking state: pattern position just past the most recent
        // '*', and the path index that star has consumed up to so far.
        var star_p: ?usize = null;
        var star_s: usize = 0;

        while (s_idx < path.len) {
            if (p_idx < pattern.len and pattern[p_idx] == '*') {
                // Collapse runs of '*' (also handles "**").
                while (p_idx < pattern.len and pattern[p_idx] == '*') p_idx += 1;
                star_p = p_idx;
                star_s = s_idx;
            } else if (p_idx < pattern.len and (pattern[p_idx] == '?' or pattern[p_idx] == path[s_idx])) {
                p_idx += 1;
                s_idx += 1;
            } else if (star_p) |sp| {
                // Mismatch after a '*': let the star absorb one more char.
                star_s += 1;
                p_idx = sp;
                s_idx = star_s;
            } else {
                return false;
            }
        }

        // Trailing '*'s can match the empty string.
        while (p_idx < pattern.len and pattern[p_idx] == '*') p_idx += 1;
        return p_idx == pattern.len;
    }
};
|
||||
|
||||
/// Default patterns always ignored (like git does)
/// Mix of exact directory/file names and `*.ext` suffix patterns;
/// consumed by matchesDefaultIgnore below.
pub const DEFAULT_IGNORES = [_][]const u8{
    ".git",
    ".ml",
    "__pycache__",
    "*.pyc",
    "*.pyo",
    ".DS_Store",
    "node_modules",
    ".venv",
    "venv",
    ".env",
    ".idea",
    ".vscode",
    "*.log",
    "*.tmp",
    "*.swp",
    "*.swo",
    "*~",
};
|
||||
|
||||
/// Check if path matches default ignores.
///
/// Both the exact path and its basename are tested: exact names
/// (".git", "node_modules", ...) and `*suffix` patterns ("*.pyc", "*~").
/// The previous version only applied suffix patterns when the path
/// contained a '/', so a bare "test.pyc" was never matched even though
/// the test below expects it to be.
pub fn matchesDefaultIgnore(path: []const u8) bool {
    // Exact matches against the whole path.
    for (DEFAULT_IGNORES) |pattern| {
        if (std.mem.eql(u8, path, pattern)) return true;
    }

    // Basename: portion after the last '/', or the whole path if none.
    const basename = if (std.mem.lastIndexOf(u8, path, "/")) |idx|
        path[idx + 1 ..]
    else
        path;

    for (DEFAULT_IGNORES) |pattern| {
        if (std.mem.startsWith(u8, pattern, "*")) {
            // Suffix pattern: "*.pyc" -> ".pyc", "*~" -> "~".
            const suffix = pattern[1..];
            if (std.mem.endsWith(u8, basename, suffix)) return true;
        } else if (std.mem.eql(u8, basename, pattern)) {
            return true;
        }
    }

    return false;
}
|
||||
|
||||
test "GitIgnore basic patterns" {
    const gpa = std.testing.allocator;

    var gi = GitIgnore.init(gpa);
    defer gi.deinit();

    try gi.parse("node_modules\n__pycache__\n*.pyc\n");

    // Named directories and extension globs are ignored; other files pass.
    try std.testing.expect(gi.isIgnored("node_modules", true));
    try std.testing.expect(gi.isIgnored("__pycache__", true));
    try std.testing.expect(gi.isIgnored("test.pyc", false));
    try std.testing.expect(!gi.isIgnored("main.py", false));
}
|
||||
|
||||
test "GitIgnore negation" {
    const gpa = std.testing.allocator;

    var gi = GitIgnore.init(gpa);
    defer gi.deinit();

    // The later negation re-includes a file the glob would ignore.
    try gi.parse("*.log\n!important.log\n");

    try std.testing.expect(gi.isIgnored("debug.log", false));
    try std.testing.expect(!gi.isIgnored("important.log", false));
}
|
||||
|
||||
test "GitIgnore directory-only" {
    const gpa = std.testing.allocator;

    var gi = GitIgnore.init(gpa);
    defer gi.deinit();

    // "build/" only matches when the entry is a directory.
    try gi.parse("build/\n");

    try std.testing.expect(gi.isIgnored("build", true));
    try std.testing.expect(!gi.isIgnored("build", false));
}
|
||||
|
||||
test "matchesDefaultIgnore" {
    // Built-in ignores: exact names and *.ext suffix patterns.
    const expected_ignored = [_][]const u8{ ".git", "__pycache__", "node_modules", "test.pyc" };
    for (expected_ignored) |p| {
        try std.testing.expect(matchesDefaultIgnore(p));
    }
    try std.testing.expect(!matchesDefaultIgnore("main.py"));
}
|
||||
|
|
@ -57,3 +57,76 @@ pub fn stdoutWriter() std.Io.Writer {
|
|||
/// Returns an unbuffered std.Io.Writer backed by stderr (via stderr_vtable,
/// declared above this block).
pub fn stderrWriter() std.Io.Writer {
    return .{ .vtable = &stderr_vtable, .buffer = &[_]u8{}, .end = 0 };
}
|
||||
|
||||
/// Write a JSON value to stdout
///
/// Serializes `value` compactly into a temporary buffer, then writes the
/// buffer and a trailing newline to stdout.
/// NOTE(review): allocates from std.heap.page_allocator directly rather
/// than a caller-supplied allocator, so every call costs at least one
/// whole page — confirm this is intentional for the CLI output path.
pub fn stdoutWriteJson(value: std.json.Value) !void {
    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(std.heap.page_allocator);
    try writeJSONValue(buf.writer(std.heap.page_allocator), value);
    var stdout_file = std.fs.File{ .handle = std.posix.STDOUT_FILENO };
    try stdout_file.writeAll(buf.items);
    try stdout_file.writeAll("\n");
}
|
||||
|
||||
/// Serialize a std.json.Value to `writer` as compact (unpadded) JSON.
/// Recurses through arrays and objects; strings go through
/// writeJSONString for escaping.
fn writeJSONValue(writer: anytype, v: std.json.Value) !void {
    switch (v) {
        .null => try writer.writeAll("null"),
        .bool => |b| try writer.print("{}", .{b}),
        .integer => |i| try writer.print("{d}", .{i}),
        .float => |f| try writer.print("{d}", .{f}),
        .string => |s| try writeJSONString(writer, s),
        .array => |arr| {
            try writer.writeAll("[");
            for (arr.items, 0..) |item, idx| {
                if (idx > 0) try writer.writeAll(",");
                try writeJSONValue(writer, item);
            }
            try writer.writeAll("]");
        },
        .object => |obj| {
            try writer.writeAll("{");
            var first = true;
            var it = obj.iterator();
            while (it.next()) |entry| {
                if (!first) try writer.writeAll(",");
                first = false;
                // Keys must be escaped like any other JSON string; printing
                // them raw produced invalid JSON for keys containing '"',
                // '\' or control characters.
                try writeJSONString(writer, entry.key_ptr.*);
                try writer.writeAll(":");
                try writeJSONValue(writer, entry.value_ptr.*);
            }
            try writer.writeAll("}");
        },
        .number_string => |s| try writer.print("{s}", .{s}),
    }
}
|
||||
|
||||
/// Emit `s` as a quoted JSON string: quotes, backslashes and common
/// whitespace escapes get two-character forms; all other bytes below
/// 0x20 are written as \u00XX; everything else is passed through as-is.
fn writeJSONString(writer: anytype, s: []const u8) !void {
    try writer.writeAll("\"");
    for (s) |ch| {
        const short_escape: ?[]const u8 = switch (ch) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            else => null,
        };
        if (short_escape) |esc| {
            try writer.writeAll(esc);
        } else if (ch < 0x20) {
            const hex = [6]u8{
                '\\', 'u', '0', '0',
                hexDigit(@intCast((ch >> 4) & 0x0F)),
                hexDigit(@intCast(ch & 0x0F)),
            };
            try writer.writeAll(&hex);
        } else {
            try writer.writeAll(&[_]u8{ch});
        }
    }
    try writer.writeAll("\"");
}
|
||||
|
||||
/// Map a nibble value (0-15) to its lowercase hex character.
fn hexDigit(v: u8) u8 {
    if (v < 10) return '0' + v;
    return 'a' + (v - 10);
}
|
||||
|
|
|
|||
122
cli/src/utils/native_bridge.zig
Normal file
122
cli/src/utils/native_bridge.zig
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
//! Native library bridge for high-performance operations
|
||||
//!
|
||||
//! Provides Zig bindings to the native/ C++ libraries:
|
||||
//! - dataset_hash: SIMD-accelerated SHA256 hashing
|
||||
//! - queue_index: High-performance task queue
|
||||
//!
|
||||
//! The native libraries provide:
|
||||
//! - 78% syscall reduction for hashing
|
||||
//! - 21,000x faster queue operations
|
||||
//! - Hardware acceleration (SHA-NI, ARMv8 crypto)
|
||||
|
||||
const std = @import("std");
|
||||
|
||||
// Link against native dataset_hash library
|
||||
const c = @cImport({
|
||||
@cInclude("dataset_hash.h");
|
||||
});
|
||||
|
||||
/// Opaque handle for native hash context
/// Wraps the C library's fh_context_t; created by fh_init.
pub const HashContext = opaque {};

/// Initialize hash context with thread pool
/// num_threads: 0 = auto-detect (capped at 8)
/// Returns null when the native library fails to initialize.
pub fn initHashContext(num_threads: u32) ?*HashContext {
    return @ptrCast(c.fh_init(num_threads));
}

/// Cleanup hash context
/// Safe to call with null (no-op).
pub fn cleanupHashContext(ctx: ?*HashContext) void {
    if (ctx) |ptr| {
        c.fh_cleanup(@ptrCast(ptr));
    }
}
|
||||
|
||||
/// Hash a single file using native SIMD implementation
/// Returns hex string (caller must free with freeString)
/// The C result string is copied into c_allocator memory and the original
/// released via fh_free_string before returning.
pub fn hashFile(ctx: ?*HashContext, path: []const u8) ![]const u8 {
    const c_path = try std.heap.c_allocator.dupeZ(u8, path);
    defer std.heap.c_allocator.free(c_path);

    const result = c.fh_hash_file(@ptrCast(ctx), c_path.ptr);
    if (result == null) {
        return error.HashFailed;
    }
    defer c.fh_free_string(result);

    const len = std.mem.len(result);
    return try std.heap.c_allocator.dupe(u8, result[0..len]);
}
|
||||
|
||||
/// Hash entire directory (parallel, combined result)
/// Returns a c_allocator-owned hex string (free with freeString).
pub fn hashDirectory(ctx: ?*HashContext, path: []const u8) ![]const u8 {
    const c_path = try std.heap.c_allocator.dupeZ(u8, path);
    defer std.heap.c_allocator.free(c_path);

    const result = c.fh_hash_directory(@ptrCast(ctx), c_path.ptr);
    if (result == null) {
        return error.HashFailed;
    }
    defer c.fh_free_string(result);

    const len = std.mem.len(result);
    return try std.heap.c_allocator.dupe(u8, result[0..len]);
}
|
||||
|
||||
/// Free string returned by native library
/// (hashFile/hashDirectory return c_allocator copies, so this frees via
/// std.heap.c_allocator — not via fh_free_string.)
pub fn freeString(str: []const u8) void {
    std.heap.c_allocator.free(str);
}
|
||||
|
||||
/// Hash data using native library (convenience function)
///
/// The native API is file-oriented, so the data is written to a temporary
/// file, hashed, and the file removed. Returns the 64-character lowercase
/// hex digest as characters (NOT raw digest bytes — the old comment
/// claiming the hex was "parsed to bytes" was wrong).
pub fn hashData(data: []const u8) ![64]u8 {
    // Unique per-call file name: the previous fixed "/tmp/fetchml_hash_tmp"
    // let concurrent callers or processes clobber each other's data.
    // (nanoTimestamp collisions within the same nanosecond are still
    // theoretically possible, but negligible for this helper.)
    var name_buf: [64]u8 = undefined;
    const tmp_path = try std.fmt.bufPrint(&name_buf, "/tmp/fetchml_hash_{d}", .{std.time.nanoTimestamp()});

    try std.fs.cwd().writeFile(.{
        .sub_path = tmp_path,
        .data = data,
    });
    defer std.fs.cwd().deleteFile(tmp_path) catch {};

    const ctx = initHashContext(0) orelse return error.InitFailed;
    defer cleanupHashContext(ctx);

    const hash_str = try hashFile(ctx, tmp_path);
    defer freeString(hash_str);

    // Guard against a short/unexpected result before the fixed-size copy.
    if (hash_str.len < 64) return error.HashFailed;

    var result: [64]u8 = undefined;
    @memcpy(&result, hash_str[0..64]);
    return result;
}
|
||||
|
||||
/// Benchmark native vs standard hashing
///
/// Warms up with one hash of `path`, then times `iterations` native hashes
/// and prints elapsed milliseconds via std.debug.print. Prints a message
/// and returns early (without error) if the native context cannot be made.
pub fn benchmark(allocator: std.mem.Allocator, path: []const u8, iterations: u32) !void {
    const ctx = initHashContext(0) orelse {
        std.debug.print("Failed to initialize native hash context\n", .{});
        return;
    };
    defer cleanupHashContext(ctx);

    var timer = try std.time.Timer.start();

    // Warm up
    _ = try hashFile(ctx, path);

    // Benchmark native
    timer.reset();
    for (0..iterations) |_| {
        const hash = try hashFile(ctx, path);
        freeString(hash);
    }
    const native_time = timer.read();

    std.debug.print("Native SIMD SHA256: {} ms for {d} iterations\n", .{
        native_time / std.time.ns_per_ms,
        iterations,
    });

    _ = allocator; // Reserved for future comparison with Zig implementation
}
|
||||
195
cli/src/utils/native_hash.zig
Normal file
195
cli/src/utils/native_hash.zig
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
const std = @import("std");
|
||||
const c = @cImport({
|
||||
@cInclude("dataset_hash.h");
|
||||
});
|
||||
|
||||
/// Native hash context for high-performance file hashing.
///
/// Thin Zig wrapper over the C `dataset_hash` library (fh_* API). All
/// returned hashes are hex strings owned by `allocator`.
pub const NativeHasher = struct {
    ctx: *c.fh_context_t,
    allocator: std.mem.Allocator,

    /// Initialize native hasher with thread pool.
    /// num_threads: 0 = auto-detect (use hardware concurrency).
    pub fn init(allocator: std.mem.Allocator, num_threads: u32) !NativeHasher {
        const ctx = c.fh_init(num_threads);
        if (ctx == null) return error.NativeInitFailed;

        return .{
            .ctx = ctx,
            .allocator = allocator,
        };
    }

    /// Cleanup native hasher and thread pool.
    pub fn deinit(self: *NativeHasher) void {
        c.fh_cleanup(self.ctx);
    }

    /// Hash a single file; returns an allocator-owned hex string.
    pub fn hashFile(self: *NativeHasher, path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, path);
        defer self.allocator.free(c_path);

        const result = c.fh_hash_file(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);

        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Batch hash multiple files (amortizes FFI overhead).
    /// Returns one allocator-owned hex string per input path.
    pub fn hashBatch(self: *NativeHasher, paths: []const []const u8) ![][]const u8 {
        // Keep the sentinel-terminated duplicates so each can be freed as
        // the exact slice it was allocated as. The previous version freed
        // through allocator.free(std.mem.span(c_ptr)) — a slice one byte
        // shorter than the dupeZ allocation — and leaked already-duplicated
        // paths if a later dupeZ failed.
        const z_paths = try self.allocator.alloc([:0]u8, paths.len);
        defer self.allocator.free(z_paths);

        var n_dup: usize = 0;
        defer for (z_paths[0..n_dup]) |zp| self.allocator.free(zp);

        const c_paths = try self.allocator.alloc([*c]const u8, paths.len);
        defer self.allocator.free(c_paths);

        for (paths, 0..) |path, i| {
            z_paths[i] = try self.allocator.dupeZ(u8, path);
            n_dup = i + 1;
            c_paths[i] = z_paths[i].ptr;
        }

        // Output array of C strings, filled by the native library.
        const results = try self.allocator.alloc([*c]u8, paths.len);
        defer self.allocator.free(results);

        const ret = c.fh_hash_batch(self.ctx, c_paths.ptr, @intCast(paths.len), results.ptr);
        if (ret != 0) return error.HashFailed;

        // Convert results to Zig strings, freeing every C string exactly
        // once even if an allocation fails partway through.
        var hashes = try self.allocator.alloc([]const u8, paths.len);
        var n_ok: usize = 0;
        errdefer {
            for (hashes[0..n_ok]) |h| self.allocator.free(h);
            self.allocator.free(hashes);
        }

        var i: usize = 0;
        while (i < results.len) : (i += 1) {
            const dup = self.allocator.dupe(u8, std.mem.span(results[i])) catch |err| {
                // Release this and all remaining C strings before bailing.
                var j = i;
                while (j < results.len) : (j += 1) c.fh_free_string(results[j]);
                return err;
            };
            c.fh_free_string(results[i]);
            hashes[i] = dup;
            n_ok = i + 1;
        }

        return hashes;
    }

    /// Hash entire directory (combined hash); allocator-owned hex string.
    pub fn hashDirectory(self: *NativeHasher, dir_path: []const u8) ![]const u8 {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);

        const result = c.fh_hash_directory(self.ctx, c_path.ptr);
        if (result == null) return error.HashFailed;
        defer c.fh_free_string(result);

        return try self.allocator.dupe(u8, std.mem.span(result));
    }

    /// Hash directory with batch output (individual file hashes).
    /// Returns up to max_results (hash, path) pairs plus the actual count.
    pub fn hashDirectoryBatch(
        self: *NativeHasher,
        dir_path: []const u8,
        max_results: u32,
    ) !struct { hashes: [][]const u8, paths: [][]const u8, count: u32 } {
        const c_path = try self.allocator.dupeZ(u8, dir_path);
        defer self.allocator.free(c_path);

        // Output arrays filled with C strings by the native library.
        const c_hashes = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(c_hashes);

        const c_names = try self.allocator.alloc([*c]u8, max_results);
        defer self.allocator.free(c_names);

        var count: u32 = 0;

        const ret = c.fh_hash_directory_batch(
            self.ctx,
            c_path.ptr,
            c_hashes.ptr,
            c_names.ptr,
            max_results,
            &count,
        );
        if (ret != 0) return error.HashFailed;

        // Free every returned C string exactly once, on success or failure.
        // (The previous version leaked the not-yet-copied C strings when a
        // dupe failed mid-loop, and its errdefer freed uninitialized
        // entries of the partially-filled Zig arrays.)
        defer for (0..count) |i| {
            c.fh_free_string(c_hashes[i]);
            c.fh_free_string(c_names[i]);
        };

        var zig_hashes = try self.allocator.alloc([]const u8, count);
        var h_ok: usize = 0;
        errdefer {
            for (zig_hashes[0..h_ok]) |h| self.allocator.free(h);
            self.allocator.free(zig_hashes);
        }

        var zig_paths = try self.allocator.alloc([]const u8, count);
        var p_ok: usize = 0;
        errdefer {
            for (zig_paths[0..p_ok]) |p| self.allocator.free(p);
            self.allocator.free(zig_paths);
        }

        for (0..count) |i| {
            zig_hashes[i] = try self.allocator.dupe(u8, std.mem.span(c_hashes[i]));
            h_ok = i + 1;

            zig_paths[i] = try self.allocator.dupe(u8, std.mem.span(c_names[i]));
            p_ok = i + 1;
        }

        return .{
            .hashes = zig_hashes,
            .paths = zig_paths,
            .count = count,
        };
    }

    /// Check if SIMD SHA-256 is available.
    pub fn hasSimd(self: *NativeHasher) bool {
        _ = self;
        return c.fh_has_simd_sha256() != 0;
    }

    /// Get implementation info (SIMD type, etc.) as a static C string.
    pub fn getImplInfo(self: *NativeHasher) []const u8 {
        _ = self;
        return std.mem.span(c.fh_get_simd_impl_name());
    }
};
|
||||
|
||||
/// Convenience function: hash directory using native library
/// Creates a throwaway NativeHasher (auto-detected thread count) and
/// returns an allocator-owned hex digest of the directory.
pub fn hashDirectoryNative(allocator: std.mem.Allocator, dir_path: []const u8) ![]const u8 {
    var hasher = try NativeHasher.init(allocator, 0); // Auto-detect threads
    defer hasher.deinit();
    return try hasher.hashDirectory(dir_path);
}
|
||||
|
||||
/// Convenience function: batch hash files using native library
/// Returns one allocator-owned hex string per input path.
pub fn hashFilesNative(
    allocator: std.mem.Allocator,
    paths: []const []const u8,
) ![][]const u8 {
    var hasher = try NativeHasher.init(allocator, 0);
    defer hasher.deinit();
    return try hasher.hashBatch(paths);
}
|
||||
|
||||
test "NativeHasher basic operations" {
    const allocator = std.testing.allocator;

    // Skip if native library not available
    var hasher = NativeHasher.init(allocator, 1) catch |err| {
        if (err == error.NativeInitFailed) {
            std.debug.print("Native library not available, skipping test\n", .{});
            return;
        }
        return err;
    };
    defer hasher.deinit();

    // Check SIMD availability — informational only (no assertion, since
    // the answer depends on the host CPU).
    const has_simd = hasher.hasSimd();
    const impl_name = hasher.getImplInfo();
    std.debug.print("SIMD: {any}, Impl: {s}\n", .{ has_simd, impl_name });
}
|
||||
231
cli/src/utils/parallel_walk.zig
Normal file
231
cli/src/utils/parallel_walk.zig
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
const std = @import("std");
|
||||
const ignore = @import("ignore.zig");
|
||||
|
||||
/// Thread-safe work queue for parallel directory walking.
///
/// Tracks `pending` — items pushed but not yet fully processed — so that
/// workers learn when the traversal is complete. The previous version had
/// no completion signal at all (setDone existed but was never called), so
/// once the queue drained every worker blocked in pop() forever and
/// parallelWalk deadlocked on join().
const WorkQueue = struct {
    items: std.ArrayList(WorkItem),
    mutex: std.Thread.Mutex,
    condition: std.Thread.Condition,
    pending: usize, // pushed-but-unfinished items
    done: bool,

    const WorkItem = struct {
        path: []const u8, // owned by the queue until popped, then by worker
        depth: usize,
    };

    fn init(allocator: std.mem.Allocator) WorkQueue {
        // Previously `items` was initialized as an unmanaged list (.empty)
        // but used through the managed API everywhere else; use the managed
        // list consistently, matching WalkResult.
        return .{
            .items = std.ArrayList(WorkItem).init(allocator),
            .mutex = .{},
            .condition = .{},
            .pending = 0,
            .done = false,
        };
    }

    fn deinit(self: *WorkQueue, allocator: std.mem.Allocator) void {
        for (self.items.items) |item| {
            allocator.free(item.path);
        }
        self.items.deinit();
    }

    fn push(self: *WorkQueue, path: []const u8, depth: usize, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();

        try self.items.append(.{
            .path = try allocator.dupe(u8, path),
            .depth = depth,
        });
        self.pending += 1;
        self.condition.signal();
    }

    /// Blocks until an item is available or the walk is complete.
    /// Returns null only once all work has finished.
    fn pop(self: *WorkQueue) ?WorkItem {
        self.mutex.lock();
        defer self.mutex.unlock();

        while (self.items.items.len == 0 and !self.done) {
            self.condition.wait(&self.mutex);
        }

        if (self.items.items.len == 0) return null;
        return self.items.pop();
    }

    /// Must be called by a worker after fully processing a popped item
    /// (including pushing discovered subdirectories). When the last
    /// outstanding item completes, all blocked workers are released.
    fn taskDone(self: *WorkQueue) void {
        self.mutex.lock();
        defer self.mutex.unlock();
        self.pending -= 1;
        if (self.pending == 0) {
            self.done = true;
            self.condition.broadcast();
        }
    }
};

/// Result from parallel directory walk (mutex-guarded file list).
const WalkResult = struct {
    files: std.ArrayList([]const u8),
    mutex: std.Thread.Mutex,

    fn init(allocator: std.mem.Allocator) WalkResult {
        return .{
            .files = std.ArrayList([]const u8).init(allocator),
            .mutex = .{},
        };
    }

    fn deinit(self: *WalkResult, allocator: std.mem.Allocator) void {
        for (self.files.items) |file| {
            allocator.free(file);
        }
        self.files.deinit();
    }

    fn add(self: *WalkResult, path: []const u8, allocator: std.mem.Allocator) !void {
        self.mutex.lock();
        defer self.mutex.unlock();
        try self.files.append(try allocator.dupe(u8, path));
    }
};

/// Shared context handed to each worker thread.
const ThreadContext = struct {
    queue: *WorkQueue,
    result: *WalkResult,
    gitignore: *ignore.GitIgnore,
    base_path: []const u8,
    allocator: std.mem.Allocator,
    max_depth: usize,
};

/// Worker thread function: drains the queue until the walk completes.
fn walkWorker(ctx: *ThreadContext) void {
    while (true) {
        const item = ctx.queue.pop() orelse break;
        // Defers run LIFO at the end of each iteration, so taskDone fires
        // only AFTER walkDirectoryParallel has pushed any subdirectories —
        // `pending` can therefore never reach zero prematurely.
        defer ctx.queue.taskDone();
        defer ctx.allocator.free(item.path);

        if (item.depth >= ctx.max_depth) continue;

        walkDirectoryParallel(ctx, item.path, item.depth) catch |err| {
            std.log.warn("Error walking {s}: {any}", .{ item.path, err });
        };
    }
}

/// Walk a single directory: files go to the result list, subdirectories
/// are queued for other workers. Paths are relative to ctx.base_path.
fn walkDirectoryParallel(ctx: *ThreadContext, dir_path: []const u8, depth: usize) !void {
    const full_path = if (std.mem.eql(u8, dir_path, "."))
        ctx.base_path
    else
        try std.fs.path.join(ctx.allocator, &[_][]const u8{ ctx.base_path, dir_path });
    defer if (!std.mem.eql(u8, dir_path, ".")) ctx.allocator.free(full_path);

    // Unreadable or vanished directories are skipped, not fatal.
    var dir = std.fs.cwd().openDir(full_path, .{ .iterate = true }) catch |err| switch (err) {
        error.AccessDenied => return,
        error.FileNotFound => return,
        else => return err,
    };
    defer dir.close();

    var it = dir.iterate();
    while (true) {
        const entry = it.next() catch |err| switch (err) {
            error.AccessDenied => continue,
            else => return err,
        } orelse break;

        const entry_path = if (std.mem.eql(u8, dir_path, "."))
            try std.fmt.allocPrint(ctx.allocator, "{s}", .{entry.name})
        else
            try std.fs.path.join(ctx.allocator, &[_][]const u8{ dir_path, entry.name });
        defer ctx.allocator.free(entry_path);

        // Built-in ignores (.git, node_modules, *.pyc, ...) first.
        if (ignore.matchesDefaultIgnore(entry_path)) continue;

        // Then user-supplied gitignore patterns.
        const is_dir = entry.kind == .directory;
        if (ctx.gitignore.isIgnored(entry_path, is_dir)) continue;

        if (is_dir) {
            // A failed push silently drops that subtree; log and continue.
            ctx.queue.push(entry_path, depth + 1, ctx.allocator) catch |err| {
                std.log.warn("Failed to queue {s}: {any}", .{ entry_path, err });
            };
        } else if (entry.kind == .file) {
            ctx.result.add(entry_path, ctx.allocator) catch |err| {
                std.log.warn("Failed to add file {s}: {any}", .{ entry_path, err });
            };
        }
    }
}

/// Parallel directory walker that uses multiple threads.
/// Returns a sorted, caller-owned list of file paths relative to base_path.
pub fn parallelWalk(
    allocator: std.mem.Allocator,
    base_path: []const u8,
    gitignore: *ignore.GitIgnore,
    num_threads: usize,
) ![][]const u8 {
    var queue = WorkQueue.init(allocator);
    defer queue.deinit(allocator);

    var result = WalkResult.init(allocator);
    defer result.deinit(allocator);

    // Seed with the base directory itself.
    try queue.push(".", 0, allocator);

    var ctx = ThreadContext{
        .queue = &queue,
        .result = &result,
        .gitignore = gitignore,
        .base_path = base_path,
        .allocator = allocator,
        .max_depth = 100, // Prevent runaway recursion (e.g. symlink loops)
    };

    // At least one worker, or the seeded item would never be processed.
    const worker_count = @max(num_threads, 1);
    const threads = try allocator.alloc(std.Thread, worker_count);
    defer allocator.free(threads);

    for (0..worker_count) |i| {
        threads[i] = try std.Thread.spawn(.{}, walkWorker, .{&ctx});
    }

    // Workers exit once `pending` drops to zero (see WorkQueue.taskDone).
    for (threads) |thread| {
        thread.join();
    }

    // Sort results for deterministic ordering.
    std.sort.block([]const u8, result.files.items, {}, struct {
        fn lessThan(_: void, a: []const u8, b: []const u8) bool {
            return std.mem.order(u8, a, b) == .lt;
        }
    }.lessThan);

    // Transfer ownership of the path strings to the caller.
    const files = try allocator.alloc([]const u8, result.files.items.len);
    @memcpy(files, result.files.items);
    result.files.items.len = 0; // Prevent deinit from freeing the strings

    return files;
}
|
||||
|
||||
test "parallelWalk basic" {
    const allocator = std.testing.allocator;

    // Empty ignore set: only the built-in default ignores apply.
    var gitignore = ignore.GitIgnore.init(allocator);
    defer gitignore.deinit();

    // Walk the current directory with 4 threads
    const files = try parallelWalk(allocator, ".", &gitignore, 4);
    defer {
        for (files) |f| allocator.free(f);
        allocator.free(files);
    }

    // Should find at least some files
    try std.testing.expect(files.len > 0);
}
|
||||
278
cli/src/utils/pii.zig
Normal file
278
cli/src/utils/pii.zig
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// PII detection patterns for research data privacy
///
/// NOTE(review): `std.regex` is not part of the Zig standard library, so
/// this struct cannot compile as written. It also appears unused by
/// detectPIISimple below — confirm whether it should be removed or backed
/// by a real regex dependency.
pub const PIIPatterns = struct {
    email: std.regex.Regex,
    ssn: std.regex.Regex,
    phone: std.regex.Regex,
    credit_card: std.regex.Regex,
    ip_address: std.regex.Regex,

    // Compile all patterns up front; any failure aborts initialization.
    pub fn init(allocator: std.mem.Allocator) !PIIPatterns {
        return PIIPatterns{
            .email = try std.regex.Regex.compile(allocator, "\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b"),
            .ssn = try std.regex.Regex.compile(allocator, "\\b\\d{3}-\\d{2}-\\d{4}\\b"),
            .phone = try std.regex.Regex.compile(allocator, "\\b\\d{3}-\\d{3}-\\d{4}\\b"),
            .credit_card = try std.regex.Regex.compile(allocator, "\\b(?:\\d[ -]*?){13,16}\\b"),
            .ip_address = try std.regex.Regex.compile(allocator, "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b"),
        };
    }

    // Release every compiled pattern.
    pub fn deinit(self: *PIIPatterns) void {
        self.email.deinit();
        self.ssn.deinit();
        self.phone.deinit();
        self.credit_card.deinit();
        self.ip_address.deinit();
    }
};
|
||||
|
||||
/// A single PII finding
pub const PIIFinding = struct {
    pii_type: []const u8, // e.g. "email", "ip_address"
    start_pos: usize, // byte offset of match start in scanned text
    end_pos: usize, // byte offset one past match end
    matched_text: []const u8, // slice INTO the scanned text (not owned)

    // Render as "type at position N: 'text'"; caller owns the result.
    pub fn format(self: PIIFinding, allocator: std.mem.Allocator) ![]u8 {
        return std.fmt.allocPrint(allocator, "{s} at position {d}: '{s}'", .{
            self.pii_type, self.start_pos, self.matched_text,
        });
    }
};
|
||||
|
||||
/// Detect PII in text - simplified version without regex
///
/// Two passes are made:
///   - email addresses: a '@' with a non-empty local part and a domain
///     containing at least one dot
///   - IPv4 addresses: dotted-quad runs with each octet parseable and <= 255
/// Returned findings hold slices INTO `text` (no copies); the caller frees
/// only the returned slice of findings.
pub fn detectPIISimple(text: []const u8, allocator: std.mem.Allocator) ![]PIIFinding {
    var findings = std.ArrayList(PIIFinding).initCapacity(allocator, 10) catch |err| {
        return err;
    };
    defer findings.deinit(allocator);

    // --- Email pass: expand around each '@'. ---
    var i: usize = 0;
    while (i < text.len) : (i += 1) {
        if (text[i] != '@') continue;

        // Walk backwards over local-part characters.
        var start = i;
        while (start > 0 and isEmailChar(text[start - 1])) {
            start -= 1;
        }

        // Walk forwards over domain characters.
        var end = i + 1;
        while (end < text.len and isEmailChar(text[end])) {
            end += 1;
        }

        // Require a non-empty local part, non-empty domain, and a dot
        // somewhere inside the domain.
        if (i > start and end > i + 1) {
            const domain = text[i + 1 .. end];
            if (std.mem.indexOfScalar(u8, domain, '.') != null) {
                try findings.append(allocator, PIIFinding{
                    .pii_type = "email",
                    .start_pos = start,
                    .end_pos = end,
                    .matched_text = text[start..end],
                });
            }
        }
    }

    // --- IPv4 pass: scan digit/dot runs shaped like xxx.xxx.xxx.xxx. ---
    i = 0;
    while (i < text.len) : (i += 1) {
        if (!std.ascii.isDigit(text[i])) continue;

        var end = i;
        var dot_count: u8 = 0;
        var digit_count: u8 = 0;

        while (end < text.len and (std.ascii.isDigit(text[end]) or text[end] == '.')) {
            if (text[end] == '.') {
                dot_count += 1;
                digit_count = 0;
            } else {
                digit_count += 1;
                if (digit_count > 3) break; // octets are at most 3 digits
            }
            if (dot_count > 3) break;
            end += 1;
        }

        // Shape check: exactly 3 dots, length 7..15 ("0.0.0.0" .. full form).
        if (dot_count == 3 and end - i >= 7 and end - i <= 15) {
            if (isValidIPv4Run(text, i, end)) {
                try findings.append(allocator, PIIFinding{
                    .pii_type = "ip_address",
                    .start_pos = i,
                    .end_pos = end,
                    .matched_text = text[i..end],
                });
                // Skip past the matched address. The previous version
                // rescanned from the next character, so "192.168.1.1" also
                // produced overlapping findings such as "92.168.1.1".
                i = end - 1; // the loop increment moves i to `end`
            }
        }
    }

    return findings.toOwnedSlice(allocator);
}

/// Returns true when text[lo..hi] is four '.'-separated decimal numbers,
/// each parseable as an integer <= 255.
fn isValidIPv4Run(text: []const u8, lo: usize, hi: usize) bool {
    var parts = std.mem.split(u8, text[lo..hi], ".");
    var octets: usize = 0;
    while (parts.next()) |part| {
        const v = std.fmt.parseInt(u32, part, 10) catch return false;
        if (v > 255) return false;
        octets += 1;
    }
    return octets == 4;
}
|
||||
|
||||
/// True for bytes that may appear in the local or domain part of an
/// email address as recognized by detectPIISimple.
fn isEmailChar(c: u8) bool {
    return switch (c) {
        '.', '_', '%', '+', '-' => true,
        else => std.ascii.isAlphanumeric(c),
    };
}
|
||||
|
||||
/// Scan `text` and, when PII is present, build a caller-owned warning
/// message listing each finding. Returns null when the text is clean.
pub fn scanForPII(text: []const u8, allocator: std.mem.Allocator) !?[]const u8 {
    const findings = try detectPIISimple(text, allocator);
    defer allocator.free(findings);

    if (findings.len == 0) return null;

    var message = try std.ArrayList(u8).initCapacity(allocator, 256);
    defer message.deinit(allocator);

    const out = message.writer(allocator);
    try out.writeAll("Warning: Potential PII detected:\n");
    for (findings) |hit| {
        try out.print(" - {s}: '{s}'\n", .{ hit.pii_type, hit.matched_text });
    }
    try out.writeAll("Use --force to store anyway, or edit your text.");

    return try message.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Redact PII from text for anonymized export.
///
/// Each detected span is replaced by a numbered placeholder such as
/// "[EMAIL-1]" or "[IP-2]"; everything else is copied through verbatim.
/// Returns a caller-owned copy of the (possibly unchanged) text.
pub fn redactPII(text: []const u8, allocator: std.mem.Allocator) ![]u8 {
    const findings = try detectPIISimple(text, allocator);
    defer allocator.free(findings);

    if (findings.len == 0) {
        return allocator.dupe(u8, text);
    }

    // Process findings in positional order so the output is assembled
    // left-to-right. Fix: `std.sort.sort` does not exist in the std
    // library version this file targets; `std.mem.sort` is the
    // supported entry point.
    std.mem.sort(PIIFinding, findings, {}, compareByStartPos);

    var result = try std.ArrayList(u8).initCapacity(allocator, text.len);
    defer result.deinit(allocator);

    const out = result.writer(allocator);
    var last_end: usize = 0;
    var redaction_counter: u32 = 0;

    for (findings) |finding| {
        // Skip findings that overlap an already-redacted span; emitting
        // them would duplicate text and break slice ordering.
        if (finding.start_pos < last_end) continue;

        // Copy the untouched text preceding this finding. (Fix: the
        // unmanaged ArrayList appendSlice requires the allocator.)
        if (finding.start_pos > last_end) {
            try result.appendSlice(allocator, text[last_end..finding.start_pos]);
        }

        // Emit a type-specific numbered placeholder.
        redaction_counter += 1;
        if (std.mem.eql(u8, finding.pii_type, "email")) {
            try out.print("[EMAIL-{d}]", .{redaction_counter});
        } else if (std.mem.eql(u8, finding.pii_type, "ip_address")) {
            try out.print("[IP-{d}]", .{redaction_counter});
        } else {
            try out.print("[REDACTED-{d}]", .{redaction_counter});
        }

        last_end = finding.end_pos;
    }

    // Copy any trailing text after the last finding.
    if (last_end < text.len) {
        try result.appendSlice(allocator, text[last_end..]);
    }

    return result.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Sort helper for std sorting: orders findings by ascending start offset.
fn compareByStartPos(_: void, a: PIIFinding, b: PIIFinding) bool {
    return a.start_pos < b.start_pos;
}
|
||||
|
||||
/// Format findings as JSON for API responses.
///
/// Produces an array of objects with "type", "start", "end", and
/// "matched" fields; the matched text is escaped per JSON string rules.
/// Caller owns the returned buffer.
pub fn formatFindingsAsJson(findings: []const PIIFinding, allocator: std.mem.Allocator) ![]u8 {
    var buf = try std.ArrayList(u8).initCapacity(allocator, 1024);
    defer buf.deinit(allocator);

    const writer = buf.writer(allocator);
    try writer.writeAll("[");

    var first = true;
    for (findings) |finding| {
        if (!first) try writer.writeAll(",");
        first = false;

        try writer.writeAll("{");
        try writer.print("\"type\":\"{s}\",", .{finding.pii_type});
        try writer.print("\"start\":{d},", .{finding.start_pos});
        try writer.print("\"end\":{d},", .{finding.end_pos});
        try writer.writeAll("\"matched\":\"");

        // JSON-escape the matched text byte by byte.
        for (finding.matched_text) |c| switch (c) {
            '"' => try writer.writeAll("\\\""),
            '\\' => try writer.writeAll("\\\\"),
            '\n' => try writer.writeAll("\\n"),
            '\r' => try writer.writeAll("\\r"),
            '\t' => try writer.writeAll("\\t"),
            else => if (c < 0x20)
                try writer.print("\\u00{x:0>2}", .{c})
            else
                try writer.writeByte(c),
        };

        try writer.writeAll("\"");
        try writer.writeAll("}");
    }

    try writer.writeAll("]");
    return buf.toOwnedSlice(allocator);
}
|
||||
|
|
@ -123,7 +123,7 @@ fn isNativeForTarget(data: []const u8) bool {
|
|||
/// 1. Download or build a static rsync binary for your target platform
|
||||
/// 2. Place it at cli/src/assets/rsync_release.bin
|
||||
/// 3. Build with: zig build prod (or release/cross targets)
|
||||
const placeholder_data = @embedFile("../assets/rsync_placeholder.bin");
|
||||
const placeholder_data = @embedFile("../assets/rsync/rsync_placeholder.bin");
|
||||
|
||||
const release_data = if (build_options.has_rsync_release)
|
||||
@embedFile(build_options.rsync_release_path)
|
||||
|
|
|
|||
36
cli/src/utils/sqlite_embedded.zig
Normal file
36
cli/src/utils/sqlite_embedded.zig
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
const std = @import("std");
|
||||
const build_options = @import("build_options");
|
||||
|
||||
/// SQLite embedding strategy (mirrors rsync pattern)
///
/// For dev builds: link against system SQLite library
/// For release builds: compile SQLite from downloaded amalgamation
///
/// To prepare for release:
/// 1. Run: make build-sqlite
/// 2. Build with: zig build prod

/// True when a vendored SQLite amalgamation is available to embed;
/// the flag is injected by the build script via build_options.
pub const USE_EMBEDDED_SQLITE = build_options.has_sqlite_release;

/// Compile flags for embedded SQLite
/// (FTS5 + JSON1 extensions, thread safety, URI filenames, column
/// metadata, and STAT4 query-planner statistics).
pub const SQLITE_FLAGS = &[_][]const u8{
    "-DSQLITE_ENABLE_FTS5",
    "-DSQLITE_ENABLE_JSON1",
    "-DSQLITE_THREADSAFE=1",
    "-DSQLITE_USE_URI",
    "-DSQLITE_ENABLE_COLUMN_METADATA",
    "-DSQLITE_ENABLE_STAT4",
};
|
||||
|
||||
/// Get SQLite include path for embedded builds.
/// Returns null when building against the system library instead.
pub fn getSqliteIncludePath(b: *std.Build) ?std.Build.LazyPath {
    if (!USE_EMBEDDED_SQLITE) return null;
    return b.path(build_options.sqlite_release_path);
}
|
||||
|
||||
/// Get SQLite source file path (sqlite3.c) for embedded builds.
/// Returns null when not embedding, or if joining the path fails.
pub fn getSqliteSourcePath(b: *std.Build) ?std.Build.LazyPath {
    if (!USE_EMBEDDED_SQLITE) return null;
    // Best-effort join; allocation failure degrades to "no embedded source".
    const path = std.fs.path.join(b.allocator, &.{ build_options.sqlite_release_path, "sqlite3.c" }) catch return null;
    return b.path(path);
}
|
||||
264
cli/src/utils/suggest.zig
Normal file
264
cli/src/utils/suggest.zig
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// Calculate Levenshtein distance between two strings using the classic
/// full dynamic-programming table (rows index s1, columns index s2).
pub fn levenshteinDistance(allocator: std.mem.Allocator, s1: []const u8, s2: []const u8) !usize {
    const rows = s1.len + 1;
    const cols = s2.len + 1;

    // Flat rows*cols table; cell (r, c) lives at table[r * cols + c].
    const table = try allocator.alloc(usize, rows * cols);
    defer allocator.free(table);

    // Base cases: transforming to/from an empty prefix costs its length.
    for (0..rows) |r| table[r * cols] = r;
    for (0..cols) |c| table[c] = c;

    // Fill the table from the three neighbors of each cell.
    for (1..rows) |r| {
        for (1..cols) |c| {
            const subst_cost: usize = if (s1[r - 1] == s2[c - 1]) 0 else 1;
            const delete = table[(r - 1) * cols + c] + 1;
            const insert = table[r * cols + (c - 1)] + 1;
            const substitute = table[(r - 1) * cols + (c - 1)] + subst_cost;
            table[r * cols + c] = @min(@min(delete, insert), substitute);
        }
    }

    // Bottom-right cell holds the full-string distance.
    return table[rows * cols - 1];
}
|
||||
|
||||
/// Find suggestions for a typo from a list of candidates.
///
/// Returns up to `max_suggestions` candidates within `max_distance`
/// edits of `input`, closest first. Each returned string is duped and
/// owned by the caller, as is the returned slice.
pub fn findSuggestions(
    allocator: std.mem.Allocator,
    input: []const u8,
    candidates: []const []const u8,
    max_distance: usize,
    max_suggestions: usize,
) ![][]const u8 {
    var matched = std.ArrayList([]const u8).empty;
    defer matched.deinit(allocator);

    var match_dists = std.ArrayList(usize).empty;
    defer match_dists.deinit(allocator);

    // Keep every candidate close enough to the input.
    for (candidates) |candidate| {
        const d = try levenshteinDistance(allocator, input, candidate);
        if (d > max_distance) continue;
        try matched.append(allocator, candidate);
        try match_dists.append(allocator, d);
    }

    // Stable bubble sort by distance, mirroring each swap in both lists
    // (fine for the handful of entries this ever sees).
    const total = match_dists.items.len;
    var pass: usize = 0;
    while (pass < total) : (pass += 1) {
        var k: usize = 0;
        while (k + 1 < total - pass) : (k += 1) {
            if (match_dists.items[k] > match_dists.items[k + 1]) {
                std.mem.swap(usize, &match_dists.items[k], &match_dists.items[k + 1]);
                std.mem.swap([]const u8, &matched.items[k], &matched.items[k + 1]);
            }
        }
    }

    // Dupe the best `max_suggestions` entries for the caller.
    const count = @min(matched.items.len, max_suggestions);
    const out = try allocator.alloc([]const u8, count);
    for (out, matched.items[0..count]) |*dst, src| {
        dst.* = try allocator.dupe(u8, src);
    }

    return out;
}
|
||||
|
||||
/// Suggest commands based on prefix matching.
///
/// Returns up to five known commands that start with `input`, or null
/// on an exact match or no match. The returned slice references static
/// storage (command-name literals); the caller must not free it, and a
/// subsequent call overwrites it — fine for the CLI's one-shot usage.
pub fn suggestCommands(input: []const u8) ?[]const []const u8 {
    const all_commands = [_][]const u8{
        "init", "sync", "queue", "requeue", "status",
        "monitor", "cancel", "prune", "watch", "dataset",
        "experiment", "narrative", "outcome", "info", "logs",
        "annotate", "validate", "compare", "find", "export",
    };

    // Exact match - no suggestion needed
    for (all_commands) |cmd| {
        if (std.mem.eql(u8, input, cmd)) return null;
    }

    // Fix: the match buffer must outlive this call. The original
    // returned a slice of a stack-local array, which dangles as soon as
    // the function returns; container-level (static) storage keeps the
    // buffer alive. The elements themselves are string literals.
    const Static = struct {
        var matches: [5][]const u8 = undefined;
    };

    // Find prefix matches (capped at five).
    var match_count: usize = 0;
    for (all_commands) |cmd| {
        if (std.mem.startsWith(u8, cmd, input)) {
            Static.matches[match_count] = cmd;
            match_count += 1;
            if (match_count >= 5) break;
        }
    }

    if (match_count == 0) return null;
    return Static.matches[0..match_count];
}
|
||||
|
||||
/// Suggest flags for a command.
///
/// Returns up to five flags (common flags first, then command-specific
/// ones, deduplicated) that start with `input`, or null when nothing
/// matches. The returned slice references static storage; the caller
/// must not free it, and a later call overwrites it.
pub fn suggestFlags(command: []const u8, input: []const u8) ?[]const []const u8 {
    // Common flags for all commands
    const common_flags = [_][]const u8{ "--help", "--verbose", "--quiet", "--json" };

    // Command-specific flags
    const queue_flags = [_][]const u8{
        "--commit", "--priority", "--cpu", "--memory", "--gpu",
        "--gpu-memory", "--hypothesis", "--context", "--intent", "--expected-outcome",
        "--experiment-group", "--tags", "--dry-run", "--validate", "--explain",
        "--force",
    };

    const find_flags = [_][]const u8{
        "--tag", "--outcome", "--dataset", "--experiment-group",
        "--author", "--after", "--before", "--limit",
    };

    const compare_flags = [_][]const u8{
        "--json", "--all", "--fields",
    };

    const export_flags = [_][]const u8{
        "--bundle", "--anonymize", "--anonymize-level", "--base",
    };

    // Select flags based on command
    const flags: []const []const u8 = switch (std.meta.stringToEnum(Command, command) orelse .unknown) {
        .queue => &queue_flags,
        .find => &find_flags,
        .compare => &compare_flags,
        .export_cmd => &export_flags,
        else => &common_flags,
    };

    // Fix: the original returned a slice of a stack-local array, which
    // dangles once the function returns. Container-level storage keeps
    // the buffer alive for the caller; elements are string literals.
    const Static = struct {
        var matches: [5][]const u8 = undefined;
    };
    var match_count: usize = 0;

    // Check common flags first
    for (common_flags) |flag| {
        if (std.mem.startsWith(u8, flag, input)) {
            Static.matches[match_count] = flag;
            match_count += 1;
            if (match_count >= 5) break;
        }
    }

    // Then check command-specific flags
    if (match_count < 5) {
        for (flags) |flag| {
            if (std.mem.startsWith(u8, flag, input)) {
                // Avoid duplicates (the command's table may repeat a
                // common flag, e.g. "--json").
                var already_added = false;
                for (0..match_count) |i| {
                    if (std.mem.eql(u8, Static.matches[i], flag)) {
                        already_added = true;
                        break;
                    }
                }
                if (!already_added) {
                    Static.matches[match_count] = flag;
                    match_count += 1;
                    if (match_count >= 5) break;
                }
            }
        }
    }

    if (match_count == 0) return null;
    return Static.matches[0..match_count];
}
|
||||
|
||||
// All CLI subcommands recognized by the flag-suggestion switch above;
// `unknown` is the fallback used when stringToEnum finds no match.
// `export_cmd` stands in for the `export` command (`export` is a Zig
// keyword and cannot be an enum tag name directly).
const Command = enum {
    init,
    sync,
    queue,
    requeue,
    status,
    monitor,
    cancel,
    prune,
    watch,
    dataset,
    experiment,
    narrative,
    outcome,
    info,
    logs,
    annotate,
    validate,
    compare,
    find,
    export_cmd,
    unknown,
};
|
||||
|
||||
/// Format suggestions into a helpful "Did you mean ...?" message.
/// Caller owns the returned buffer; an empty suggestion list yields "".
pub fn formatSuggestionMessage(
    allocator: std.mem.Allocator,
    input: []const u8,
    suggestions: []const []const u8,
) ![]u8 {
    if (suggestions.len == 0) return allocator.dupe(u8, "");

    var buf = std.ArrayList(u8).empty;
    defer buf.deinit(allocator);

    const writer = buf.writer(allocator);
    try writer.print("Did you mean for '{s}': ", .{input});

    for (suggestions, 0..) |candidate, idx| {
        // Join with commas, using " or " before the final entry.
        if (idx != 0) {
            const sep: []const u8 = if (idx == suggestions.len - 1) " or " else ", ";
            try writer.writeAll(sep);
        }
        try writer.print("'{s}'", .{candidate});
    }

    try writer.writeAll("?\n");

    return buf.toOwnedSlice(allocator);
}
|
||||
|
||||
/// Smoke-test the suggestion system (edit distance plus ranking).
pub fn testSuggestions() !void {
    const allocator = std.testing.allocator;

    // Single-edit distances.
    std.debug.assert((try levenshteinDistance(allocator, "queue", "quee")) == 1);
    std.debug.assert((try levenshteinDistance(allocator, "status", "statis")) == 1);

    // The closest candidate should rank first.
    const candidates = [_][]const u8{ "queue", "query", "quiet", "quit" };
    const suggestions = try findSuggestions(allocator, "quee", &candidates, 2, 3);
    defer {
        for (suggestions) |s| allocator.free(s);
        allocator.free(suggestions);
    }
    std.debug.assert(suggestions.len > 0);
    std.debug.assert(std.mem.eql(u8, suggestions[0], "queue"));

    std.debug.print("Suggestion tests passed!\n", .{});
}
|
||||
44
cli/src/utils/uuid.zig
Normal file
44
cli/src/utils/uuid.zig
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
const std = @import("std");
|
||||
|
||||
/// UUID v4 generator - generates random UUIDs.
/// Returns the canonical 36-character lowercase-hex form
/// (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx); caller owns the string.
pub fn generateV4(allocator: std.mem.Allocator) ![]const u8 {
    var raw: [16]u8 = undefined;
    std.crypto.random.bytes(&raw);

    // Stamp RFC 4122 version 4 and variant 10 bits.
    raw[6] = (raw[6] & 0x0F) | 0x40;
    raw[8] = (raw[8] & 0x3F) | 0x80;

    const out = try allocator.alloc(u8, 36);
    const digits = "0123456789abcdef";

    var w: usize = 0;
    for (raw, 0..) |byte, idx| {
        out[w] = digits[byte >> 4];
        out[w + 1] = digits[byte & 0x0F];
        w += 2;

        // Dashes follow bytes 3, 5, 7, and 9 (string cols 8/13/18/23).
        switch (idx) {
            3, 5, 7, 9 => {
                out[w] = '-';
                w += 1;
            },
            else => {},
        }
    }

    return out;
}
|
||||
|
||||
/// Generate a simple random ID (shorter than UUID, for internal use).
/// Characters are drawn uniformly from [a-z0-9]; caller owns the result.
pub fn generateSimpleID(allocator: std.mem.Allocator, length: usize) ![]const u8 {
    const chars = "abcdefghijklmnopqrstuvwxyz0123456789";
    const id = try allocator.alloc(u8, length);

    for (id) |*c| {
        // uintLessThan yields an unbiased index; the previous
        // `random.int(usize) % chars.len` carried a (tiny, but
        // avoidable) modulo bias.
        c.* = chars[std.crypto.random.uintLessThan(usize, chars.len)];
    }

    return id;
}
|
||||
283
cli/src/utils/watch.zig
Normal file
283
cli/src/utils/watch.zig
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
const std = @import("std");
|
||||
const os = std.os;
|
||||
const log = std.log;
|
||||
|
||||
/// File watcher using OS-native APIs (kqueue on macOS, inotify on Linux)
/// Zero third-party dependencies - uses only standard library OS bindings
///
/// NOTE(review): this uses the legacy `std.os` surface (os.kqueue,
/// os.open, os.poll, inotify_*); newer Zig releases moved these to
/// `std.posix` — confirm against the toolchain version this builds with.
pub const FileWatcher = struct {
    allocator: std.mem.Allocator,
    // Set of watched directory paths; keys are heap-duplicated in
    // watchDirectory and freed in deinit.
    watched_paths: std.StringHashMap(void),

    // Platform-specific handles
    // (the inactive platform's field collapses to zero-size `void`).
    kqueue_fd: if (@import("builtin").target.os.tag == .macos) i32 else void,
    inotify_fd: if (@import("builtin").target.os.tag == .linux) i32 else void,

    // Debounce timer
    last_event_time: i64,
    debounce_ms: i64,

    /// Open the platform event handle (kqueue on macOS, inotify on
    /// Linux); fails with error.UnsupportedPlatform elsewhere.
    pub fn init(allocator: std.mem.Allocator, debounce_ms: i64) !FileWatcher {
        const target = @import("builtin").target;

        var watcher = FileWatcher{
            .allocator = allocator,
            .watched_paths = std.StringHashMap(void).init(allocator),
            .kqueue_fd = if (target.os.tag == .macos) -1 else {},
            .inotify_fd = if (target.os.tag == .linux) -1 else {},
            .last_event_time = 0,
            .debounce_ms = debounce_ms,
        };

        switch (target.os.tag) {
            .macos => {
                watcher.kqueue_fd = try os.kqueue();
            },
            .linux => {
                watcher.inotify_fd = try std.os.inotify_init1(os.linux.IN_CLOEXEC);
            },
            else => {
                return error.UnsupportedPlatform;
            },
        }

        return watcher;
    }

    /// Close the platform handle and free all stored path keys.
    pub fn deinit(self: *FileWatcher) void {
        const target = @import("builtin").target;

        switch (target.os.tag) {
            .macos => {
                if (self.kqueue_fd != -1) {
                    os.close(self.kqueue_fd);
                }
            },
            .linux => {
                if (self.inotify_fd != -1) {
                    os.close(self.inotify_fd);
                }
            },
            else => {},
        }

        var it = self.watched_paths.keyIterator();
        while (it.next()) |key| {
            self.allocator.free(key.*);
        }
        self.watched_paths.deinit();
    }

    /// Add a directory to watch recursively
    /// (silently skips paths that vanish or deny access mid-walk).
    pub fn watchDirectory(self: *FileWatcher, path: []const u8) !void {
        // Already registered — nothing to do.
        if (self.watched_paths.contains(path)) return;

        const path_copy = try self.allocator.dupe(u8, path);
        try self.watched_paths.put(path_copy, {});

        const target = @import("builtin").target;

        switch (target.os.tag) {
            .macos => try self.addKqueueWatch(path),
            .linux => try self.addInotifyWatch(path),
            else => return error.UnsupportedPlatform,
        }

        // Recursively watch subdirectories
        var dir = std.fs.cwd().openDir(path, .{ .iterate = true }) catch |err| switch (err) {
            error.AccessDenied => return,
            error.FileNotFound => return,
            else => return err,
        };
        defer dir.close();

        var it = dir.iterate();
        while (true) {
            const entry = it.next() catch |err| switch (err) {
                error.AccessDenied => continue,
                else => return err,
            } orelse break;

            if (entry.kind == .directory) {
                // subpath is only needed for the recursive call; the
                // recursion dupes its own copy into watched_paths.
                const subpath = try std.fs.path.join(self.allocator, &[_][]const u8{ path, entry.name });
                defer self.allocator.free(subpath);
                try self.watchDirectory(subpath);
            }
        }
    }

    /// Add kqueue watch for macOS
    /// (registers an EVFILT_VNODE filter for write/rename/attr events).
    fn addKqueueWatch(self: *FileWatcher, path: []const u8) !void {
        if (@import("builtin").target.os.tag != .macos) return;

        // NOTE(review): fd is closed on return; kqueue vnode filters
        // normally require the fd to stay open for events to fire —
        // confirm this is intentional.
        const fd = try os.open(path, os.O.EVTONLY | os.O_RDONLY, 0);
        defer os.close(fd);

        const event = os.Kevent{
            .ident = @intCast(fd),
            .filter = os.EVFILT_VNODE,
            .flags = os.EV_ADD | os.EV_CLEAR,
            .fflags = os.NOTE_WRITE | os.NOTE_EXTEND | os.NOTE_ATTRIB | os.NOTE_LINK | os.NOTE_RENAME | os.NOTE_REVOKE,
            .data = 0,
            .udata = 0,
        };

        const changes = [_]os.Kevent{event};
        _ = try os.kevent(self.kqueue_fd, &changes, &.{}, null);
    }

    /// Add inotify watch for Linux
    fn addInotifyWatch(self: *FileWatcher, path: []const u8) !void {
        if (@import("builtin").target.os.tag != .linux) return;

        const mask = os.linux.IN_MODIFY | os.linux.IN_CREATE | os.linux.IN_DELETE |
            os.linux.IN_MOVED_FROM | os.linux.IN_MOVED_TO | os.linux.IN_ATTRIB;

        // NOTE(review): inotify_add_watch expects a NUL-terminated path;
        // `path.ptr` on a plain slice is not guaranteed to be — verify.
        const wd = try os.linux.inotify_add_watch(self.inotify_fd, path.ptr, mask);
        if (wd < 0) return error.InotifyError;
    }

    /// Wait for file changes with debouncing
    /// Returns true when a (debounced) change occurred; timeout_ms < 0
    /// blocks indefinitely.
    pub fn waitForChanges(self: *FileWatcher, timeout_ms: i32) !bool {
        const target = @import("builtin").target;

        switch (target.os.tag) {
            .macos => return try self.waitKqueue(timeout_ms),
            .linux => return try self.waitInotify(timeout_ms),
            else => return error.UnsupportedPlatform,
        }
    }

    /// kqueue wait implementation
    fn waitKqueue(self: *FileWatcher, timeout_ms: i32) !bool {
        if (@import("builtin").target.os.tag != .macos) return false;

        var ts: os.timespec = undefined;
        if (timeout_ms >= 0) {
            ts.tv_sec = @divTrunc(timeout_ms, 1000);
            ts.tv_nsec = @mod(timeout_ms, 1000) * 1000000;
        }

        var events: [10]os.Kevent = undefined;
        const nev = os.kevent(self.kqueue_fd, &.{}, &events, if (timeout_ms >= 0) &ts else null) catch |err| switch (err) {
            error.ETIME => return false, // Timeout
            else => return err,
        };

        // Debounce: collapse a burst of events into one notification.
        if (nev > 0) {
            const now = std.time.milliTimestamp();
            if (now - self.last_event_time > self.debounce_ms) {
                self.last_event_time = now;
                return true;
            }
        }

        return false;
    }

    /// inotify wait implementation
    fn waitInotify(self: *FileWatcher, timeout_ms: i32) !bool {
        if (@import("builtin").target.os.tag != .linux) return false;

        var fds = [_]os.pollfd{.{
            .fd = self.inotify_fd,
            .events = os.POLLIN,
            .revents = 0,
        }};

        const ready = os.poll(&fds, timeout_ms) catch |err| switch (err) {
            error.ETIME => return false,
            else => return err,
        };

        if (ready > 0 and (fds[0].revents & os.POLLIN) != 0) {
            // Drain the event queue; the payload itself is unused —
            // any activity counts as "changed".
            var buf: [4096]u8 align(@alignOf(os.linux.inotify_event)) = undefined;

            const bytes_read = try os.read(self.inotify_fd, &buf);
            if (bytes_read > 0) {
                const now = std.time.milliTimestamp();
                if (now - self.last_event_time > self.debounce_ms) {
                    self.last_event_time = now;
                    return true;
                }
            }
        }

        return false;
    }

    /// Run watch loop with callback
    /// Blocks forever, invoking `callback` on each debounced change.
    pub fn run(self: *FileWatcher, callback: fn () void) !void {
        log.info("Watching for file changes (debounce: {d}ms)...", .{self.debounce_ms});

        while (true) {
            if (try self.waitForChanges(-1)) {
                callback();
            }
        }
    }
};
|
||||
|
||||
/// Watch command handler
///
/// Usage: ml watch <path> [--sync] [--queue]
/// Watches `path` recursively and logs on every debounced change;
/// --sync / --queue select the (currently stubbed) follow-up actions.
/// Blocks forever once watching starts.
pub fn watchCommand(allocator: std.mem.Allocator, args: []const []const u8) !void {
    if (args.len < 1) {
        std.debug.print("Usage: ml watch <path> [--sync] [--queue]\n", .{});
        return error.InvalidArgs;
    }

    const path = args[0];
    var auto_sync = false;
    var auto_queue = false;

    // Parse flags
    for (args[1..]) |arg| {
        if (std.mem.eql(u8, arg, "--sync")) auto_sync = true;
        if (std.mem.eql(u8, arg, "--queue")) auto_queue = true;
    }

    var watcher = try FileWatcher.init(allocator, 100); // 100ms debounce
    defer watcher.deinit();

    try watcher.watchDirectory(path);

    log.info("Watching {s} for changes...", .{path});

    // Callback for file changes
    // (context is assembled for the future sync/queue hooks; today it is
    // only consumed by the `_ = ctx;` discard in the loop below).
    const CallbackContext = struct {
        allocator: std.mem.Allocator,
        path: []const u8,
        auto_sync: bool,
        auto_queue: bool,
    };

    const ctx = CallbackContext{
        .allocator = allocator,
        .path = path,
        .auto_sync = auto_sync,
        .auto_queue = auto_queue,
    };

    // Run watch loop
    while (true) {
        if (try watcher.waitForChanges(-1)) {
            log.info("File changes detected", .{});

            if (auto_sync) {
                log.info("Auto-syncing...", .{});
                // Trigger sync (implementation would call sync command)
                _ = ctx;
            }

            if (auto_queue) {
                log.info("Auto-queuing...", .{});
                // Trigger queue (implementation would call queue command)
            }
        }
    }
}
|
||||
|
||||
test "FileWatcher init/deinit" {
    const allocator = std.testing.allocator;

    // Sanity check: constructing and tearing down a watcher leaks
    // nothing (std.testing.allocator fails the test on leaks).
    var watcher = try FileWatcher.init(allocator, 100);
    defer watcher.deinit();
}
|
||||
|
|
@ -21,6 +21,7 @@ func main() {
|
|||
apiKey := flag.String("api-key", "", "API key for authentication")
|
||||
showVersion := flag.Bool("version", false, "Show version and build info")
|
||||
verifyBuild := flag.Bool("verify", false, "Verify build integrity")
|
||||
securityAudit := flag.Bool("security-audit", false, "Run security audit and exit")
|
||||
flag.Parse()
|
||||
|
||||
// Handle version display
|
||||
|
|
@ -38,6 +39,12 @@ func main() {
|
|||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Handle security audit
|
||||
if *securityAudit {
|
||||
runSecurityAudit(*configFile)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Create and start server
|
||||
server, err := api.NewServer(*configFile)
|
||||
if err != nil {
|
||||
|
|
@ -54,3 +61,75 @@ func main() {
|
|||
// Reserved for future authentication enhancements
|
||||
_ = apiKey
|
||||
}
|
||||
|
||||
// runSecurityAudit performs security checks and reports issues:
// config-file permissions, leaked sensitive environment variables,
// root execution, and API-key-file permissions. Results are printed
// to stdout; the function itself never fails.
func runSecurityAudit(configFile string) {
	fmt.Println("=== Security Audit ===")

	var (
		issues   []string
		warnings []string
	)

	// Check 1: Config file permissions (no group/world access allowed).
	if info, err := os.Stat(configFile); err != nil {
		warnings = append(warnings, fmt.Sprintf("Could not check config file: %v", err))
	} else if mode := info.Mode().Perm(); mode&0077 != 0 {
		issues = append(issues, fmt.Sprintf("Config file %s is world/group readable (permissions: %04o)", configFile, mode))
	} else {
		fmt.Printf("✓ Config file permissions: %04o (secure)\n", mode)
	}

	// Check 2: Environment variable exposure.
	var exposedVars []string
	for _, name := range []string{"JWT_SECRET", "FETCH_ML_API_KEY", "DATABASE_PASSWORD", "REDIS_PASSWORD"} {
		if os.Getenv(name) != "" {
			exposedVars = append(exposedVars, name)
		}
	}
	if len(exposedVars) == 0 {
		fmt.Println("✓ No sensitive environment variables exposed")
	} else {
		warnings = append(warnings, fmt.Sprintf("Sensitive environment variables exposed: %v (will be cleared on startup)", exposedVars))
	}

	// Check 3: Running as root.
	if uid := os.Getuid(); uid == 0 {
		issues = append(issues, "Running as root (UID 0) - should run as non-root user")
	} else {
		fmt.Printf("✓ Running as non-root user (UID: %d)\n", uid)
	}

	// Check 4: API key file permissions (only when the env var is set).
	if apiKeyFile := os.Getenv("FETCH_ML_API_KEY_FILE"); apiKeyFile != "" {
		if info, err := os.Stat(apiKeyFile); err == nil {
			if mode := info.Mode().Perm(); mode&0077 != 0 {
				issues = append(issues, fmt.Sprintf("API key file %s is world/group readable (permissions: %04o)", apiKeyFile, mode))
			} else {
				fmt.Printf("✓ API key file permissions: %04o (secure)\n", mode)
			}
		}
	}

	// Report results.
	fmt.Println()
	if len(issues) == 0 && len(warnings) == 0 {
		fmt.Println("✓ All security checks passed")
		return
	}
	if len(issues) > 0 {
		fmt.Printf("✗ Found %d security issue(s):\n", len(issues))
		for _, issue := range issues {
			fmt.Printf(" - %s\n", issue)
		}
	}
	if len(warnings) > 0 {
		fmt.Printf("⚠ Found %d warning(s):\n", len(warnings))
		for _, warning := range warnings {
			fmt.Printf(" - %s\n", warning)
		}
	}
}
|
||||
|
|
|
|||
54
cmd/gen-keys/main.go
Normal file
54
cmd/gen-keys/main.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
// Package main implements a tool for generating Ed25519 signing keys
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/crypto"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
outDir = flag.String("out", "./keys", "Output directory for keys")
|
||||
keyID = flag.String("key-id", "manifest-signer-1", "Key identifier")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
// Create output directory
|
||||
if err := os.MkdirAll(*outDir, 0700); err != nil {
|
||||
log.Fatalf("Failed to create output directory: %v", err)
|
||||
}
|
||||
|
||||
// Generate keypair
|
||||
publicKey, privateKey, err := crypto.GenerateSigningKeys()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate signing keys: %v", err)
|
||||
}
|
||||
|
||||
// Define paths
|
||||
privKeyPath := fmt.Sprintf("%s/%s_private.key", *outDir, *keyID)
|
||||
pubKeyPath := fmt.Sprintf("%s/%s_public.key", *outDir, *keyID)
|
||||
|
||||
// Save private key (restricted permissions)
|
||||
if err := crypto.SavePrivateKeyToFile(privateKey, privKeyPath); err != nil {
|
||||
log.Fatalf("Failed to save private key: %v", err)
|
||||
}
|
||||
|
||||
// Save public key
|
||||
if err := crypto.SavePublicKeyToFile(publicKey, pubKeyPath); err != nil {
|
||||
log.Fatalf("Failed to save public key: %v", err)
|
||||
}
|
||||
|
||||
// Print summary
|
||||
fmt.Printf("✓ Generated Ed25519 signing keys\n")
|
||||
fmt.Printf(" Key ID: %s\n", *keyID)
|
||||
fmt.Printf(" Private key: %s (permissions: 0600)\n", privKeyPath)
|
||||
fmt.Printf(" Public key: %s\n", pubKeyPath)
|
||||
fmt.Printf("\nImportant:\n")
|
||||
fmt.Printf(" - Store the private key securely (it can sign manifests)\n")
|
||||
fmt.Printf(" - Distribute the public key to verification systems\n")
|
||||
fmt.Printf(" - Set environment variable: FETCHML_SIGNING_KEY_PATH=%s\n", privKeyPath)
|
||||
}
|
||||
|
|
@ -356,7 +356,8 @@ api_key = "your_api_key_here" # Your API key (get from admin)
|
|||
|
||||
// Set proper permissions
|
||||
if err := auth.CheckConfigFilePermissions(configPath); err != nil {
|
||||
log.Printf("Warning: %v", err)
|
||||
// Log permission warning but don't fail
|
||||
_ = err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -17,12 +17,25 @@ type Config struct {
|
|||
SSHKey string `toml:"ssh_key"`
|
||||
Port int `toml:"port"`
|
||||
BasePath string `toml:"base_path"`
|
||||
Mode string `toml:"mode"` // "dev" or "prod"
|
||||
WrapperScript string `toml:"wrapper_script"`
|
||||
TrainScript string `toml:"train_script"`
|
||||
RedisAddr string `toml:"redis_addr"`
|
||||
RedisPassword string `toml:"redis_password"`
|
||||
RedisDB int `toml:"redis_db"`
|
||||
KnownHosts string `toml:"known_hosts"`
|
||||
ServerURL string `toml:"server_url"` // WebSocket server URL (e.g., ws://localhost:8080)
|
||||
|
||||
// Local mode configuration
|
||||
DBPath string `toml:"db_path"` // Path to SQLite database (local mode)
|
||||
ForceLocal bool `toml:"force_local"` // Force local-only mode
|
||||
ProjectRoot string `toml:"project_root"` // Project root for local mode
|
||||
|
||||
// Experiment configuration
|
||||
Experiment struct {
|
||||
Name string `toml:"name"`
|
||||
Entrypoint string `toml:"entrypoint"`
|
||||
} `toml:"experiment"`
|
||||
|
||||
// Authentication
|
||||
Auth auth.Config `toml:"auth"`
|
||||
|
|
@ -133,6 +146,20 @@ func (c *Config) Validate() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Set default mode if not specified
|
||||
if c.Mode == "" {
|
||||
if os.Getenv("FETCH_ML_TUI_MODE") != "" {
|
||||
c.Mode = os.Getenv("FETCH_ML_TUI_MODE")
|
||||
} else {
|
||||
c.Mode = "dev" // Default to dev mode
|
||||
}
|
||||
}
|
||||
|
||||
// Set mode-appropriate default paths using project-relative paths
|
||||
if c.BasePath == "" {
|
||||
c.BasePath = utils.ModeBasedBasePath(c.Mode)
|
||||
}
|
||||
|
||||
if c.BasePath != "" {
|
||||
// Convert relative paths to absolute
|
||||
c.BasePath = utils.ExpandPath(c.BasePath)
|
||||
|
|
@ -150,7 +177,29 @@ func (c *Config) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// PendingPath returns the path for pending experiments
|
||||
// IsLocalMode returns true if the TUI should operate in local-only mode
|
||||
func (c *Config) IsLocalMode() bool {
|
||||
if c.ForceLocal {
|
||||
return true
|
||||
}
|
||||
// Check if tracking_uri indicates local mode (sqlite:// prefix)
|
||||
if c.DBPath != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetDBPath returns the SQLite database path for local mode
|
||||
func (c *Config) GetDBPath() string {
|
||||
if c.DBPath != "" {
|
||||
return c.DBPath
|
||||
}
|
||||
// Default location: ~/.fetchml/experiments.db
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return filepath.Join(home, ".fetchml", "experiments.db")
|
||||
}
|
||||
return "fetchml.db"
|
||||
}
|
||||
func (c *Config) PendingPath() string { return filepath.Join(c.BasePath, "pending") }
|
||||
|
||||
// RunningPath returns the path for running experiments
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@ package controller
|
|||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/internal/container"
|
||||
"github.com/jfraeys/fetch_ml/internal/worker"
|
||||
)
|
||||
|
||||
func shellQuote(s string) string {
|
||||
|
|
@ -24,6 +26,7 @@ func (c *Controller) loadAllData() tea.Cmd {
|
|||
c.loadQueue(),
|
||||
c.loadGPU(),
|
||||
c.loadContainer(),
|
||||
c.loadDatasets(),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -39,6 +42,13 @@ func (c *Controller) loadJobs() tea.Cmd {
|
|||
var jobs []model.Job
|
||||
statusChan := make(chan []model.Job, 4)
|
||||
|
||||
// Debug: Print paths being used
|
||||
c.logger.Info("Loading jobs from paths",
|
||||
"pending", c.getPathForStatus(model.StatusPending),
|
||||
"running", c.getPathForStatus(model.StatusRunning),
|
||||
"finished", c.getPathForStatus(model.StatusFinished),
|
||||
"failed", c.getPathForStatus(model.StatusFailed))
|
||||
|
||||
for _, status := range []model.JobStatus{
|
||||
model.StatusPending,
|
||||
model.StatusRunning,
|
||||
|
|
@ -48,22 +58,18 @@ func (c *Controller) loadJobs() tea.Cmd {
|
|||
go func(s model.JobStatus) {
|
||||
path := c.getPathForStatus(s)
|
||||
names := c.server.ListDir(path)
|
||||
|
||||
// Debug: Log what we found
|
||||
c.logger.Info("Listed directory", "status", s, "path", path, "count", len(names))
|
||||
|
||||
var statusJobs []model.Job
|
||||
for _, name := range names {
|
||||
jobStatus, _ := c.taskQueue.GetJobStatus(name)
|
||||
taskID := jobStatus["task_id"]
|
||||
priority := int64(0)
|
||||
if p, ok := jobStatus["priority"]; ok {
|
||||
_, err := fmt.Sscanf(p, "%d", &priority)
|
||||
if err != nil {
|
||||
priority = 0
|
||||
}
|
||||
}
|
||||
// Lazy loading: only fetch basic info for list view
|
||||
// Full details (GPU, narrative) loaded on selection
|
||||
statusJobs = append(statusJobs, model.Job{
|
||||
Name: name,
|
||||
Status: s,
|
||||
TaskID: taskID,
|
||||
Priority: priority,
|
||||
Name: name,
|
||||
Status: s,
|
||||
// TaskID, Priority, GPU info loaded lazily
|
||||
})
|
||||
}
|
||||
statusChan <- statusJobs
|
||||
|
|
@ -106,52 +112,51 @@ func (c *Controller) loadGPU() tea.Cmd {
|
|||
|
||||
resultChan := make(chan gpuResult, 1)
|
||||
go func() {
|
||||
cmd := "nvidia-smi --query-gpu=index,name,utilization.gpu," +
|
||||
"memory.used,memory.total,temperature.gpu --format=csv,noheader,nounits"
|
||||
out, err := c.server.Exec(cmd)
|
||||
if err == nil && strings.TrimSpace(out) != "" {
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString("GPU Status\n")
|
||||
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
for _, line := range lines {
|
||||
parts := strings.Split(line, ", ")
|
||||
if len(parts) >= 6 {
|
||||
formatted.WriteString(fmt.Sprintf("🎮 GPU %s: %s\n", parts[0], parts[1]))
|
||||
formatted.WriteString(fmt.Sprintf(" Utilization: %s%%\n", parts[2]))
|
||||
formatted.WriteString(fmt.Sprintf(" Memory: %s/%s MB\n", parts[3], parts[4]))
|
||||
formatted.WriteString(fmt.Sprintf(" Temperature: %s°C\n\n", parts[5]))
|
||||
// Try NVML first for accurate GPU info (Linux/Windows with NVIDIA)
|
||||
if worker.IsNVMLAvailable() {
|
||||
gpus, err := worker.GetAllGPUInfo()
|
||||
if err == nil && len(gpus) > 0 {
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString("GPU Status (NVML)\n")
|
||||
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
|
||||
for _, gpu := range gpus {
|
||||
formatted.WriteString(fmt.Sprintf("🎮 GPU %d: %s\n", gpu.Index, gpu.Name))
|
||||
formatted.WriteString(fmt.Sprintf(" Utilization: %d%%\n", gpu.Utilization))
|
||||
formatted.WriteString(fmt.Sprintf(" Memory: %d/%d MB\n",
|
||||
gpu.MemoryUsed/1024/1024, gpu.MemoryTotal/1024/1024))
|
||||
formatted.WriteString(fmt.Sprintf(" Temperature: %d°C\n", gpu.Temperature))
|
||||
if gpu.PowerDraw > 0 {
|
||||
formatted.WriteString(fmt.Sprintf(" Power: %.1f W\n", float64(gpu.PowerDraw)/1000.0))
|
||||
}
|
||||
if gpu.ClockSM > 0 {
|
||||
formatted.WriteString(fmt.Sprintf(" SM Clock: %d MHz\n", gpu.ClockSM))
|
||||
}
|
||||
formatted.WriteString("\n")
|
||||
}
|
||||
}
|
||||
c.logger.Info("loaded GPU status", "type", "nvidia")
|
||||
resultChan <- gpuResult{content: formatted.String(), err: nil}
|
||||
return
|
||||
}
|
||||
|
||||
cmd = "system_profiler SPDisplaysDataType | grep 'Chipset Model\\|VRAM' | head -2"
|
||||
out, err = c.server.Exec(cmd)
|
||||
if err != nil {
|
||||
c.logger.Warn("GPU info unavailable", "error", err)
|
||||
resultChan <- gpuResult{
|
||||
content: "GPU info unavailable\n\nRun on a system with nvidia-smi or macOS GPU",
|
||||
err: err,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var formatted strings.Builder
|
||||
formatted.WriteString("GPU Status (macOS)\n")
|
||||
formatted.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
lines := strings.Split(strings.TrimSpace(out), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "Chipset Model") || strings.Contains(line, "VRAM") {
|
||||
formatted.WriteString("🎮 " + strings.TrimSpace(line) + "\n")
|
||||
c.logger.Info("loaded GPU status", "type", "nvml", "count", len(gpus))
|
||||
resultChan <- gpuResult{content: formatted.String(), err: nil}
|
||||
return
|
||||
}
|
||||
}
|
||||
formatted.WriteString("\n💡 Note: nvidia-smi not available on macOS\n")
|
||||
|
||||
c.logger.Info("loaded GPU status", "type", "macos")
|
||||
resultChan <- gpuResult{content: formatted.String(), err: nil}
|
||||
// Try macOS GPU monitoring (development mode on macOS)
|
||||
if worker.IsMacOS() {
|
||||
gpuStatus, err := worker.FormatMacOSGPUStatus()
|
||||
if err == nil && gpuStatus != "" {
|
||||
c.logger.Info("loaded GPU status", "type", "macos")
|
||||
resultChan <- gpuResult{content: gpuStatus, err: nil}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// No GPU monitoring available
|
||||
c.logger.Warn("GPU info unavailable", "platform", runtime.GOOS)
|
||||
resultChan <- gpuResult{
|
||||
content: "GPU info unavailable\n\nNVML: NVIDIA driver not installed or incompatible\nmacOS: system_profiler not available",
|
||||
err: fmt.Errorf("no GPU monitoring available on %s", runtime.GOOS),
|
||||
}
|
||||
}()
|
||||
|
||||
result := <-resultChan
|
||||
|
|
@ -362,6 +367,18 @@ func (c *Controller) showQueue(m model.State) tea.Cmd {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadDatasets() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
datasets, err := c.taskQueue.ListDatasets()
|
||||
if err != nil {
|
||||
c.logger.Error("failed to load datasets", "error", err)
|
||||
return model.StatusMsg{Text: "Failed to load datasets: " + err.Error(), Level: "error"}
|
||||
}
|
||||
c.logger.Info("loaded datasets", "count", len(datasets))
|
||||
return model.DatasetsLoadedMsg(datasets)
|
||||
}
|
||||
}
|
||||
|
||||
func tickCmd() tea.Cmd {
|
||||
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
|
||||
return model.TickMsg(t)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package controller
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
|
|
@ -19,6 +20,7 @@ type Controller struct {
|
|||
server *services.MLServer
|
||||
taskQueue *services.TaskQueue
|
||||
logger *logging.Logger
|
||||
wsClient *services.WebSocketClient
|
||||
}
|
||||
|
||||
func (c *Controller) handleKeyMsg(msg tea.KeyMsg, m model.State) (model.State, tea.Cmd) {
|
||||
|
|
@ -143,6 +145,33 @@ func (c *Controller) handleGlobalKeys(msg tea.KeyMsg, m *model.State) []tea.Cmd
|
|||
case key.Matches(msg, m.Keys.ViewExperiments):
|
||||
m.ActiveView = model.ViewModeExperiments
|
||||
cmds = append(cmds, c.loadExperiments())
|
||||
case key.Matches(msg, m.Keys.ViewNarrative):
|
||||
m.ActiveView = model.ViewModeNarrative
|
||||
if job := getSelectedJob(*m); job != nil {
|
||||
m.SelectedJob = *job
|
||||
}
|
||||
case key.Matches(msg, m.Keys.ViewTeam):
|
||||
m.ActiveView = model.ViewModeTeam
|
||||
case key.Matches(msg, m.Keys.ViewExperimentHistory):
|
||||
m.ActiveView = model.ViewModeExperimentHistory
|
||||
cmds = append(cmds, c.loadExperimentHistory())
|
||||
case key.Matches(msg, m.Keys.ViewConfig):
|
||||
m.ActiveView = model.ViewModeConfig
|
||||
cmds = append(cmds, c.loadConfig())
|
||||
case key.Matches(msg, m.Keys.ViewLogs):
|
||||
m.ActiveView = model.ViewModeLogs
|
||||
if job := getSelectedJob(*m); job != nil {
|
||||
cmds = append(cmds, c.loadLogs(job.Name))
|
||||
}
|
||||
case key.Matches(msg, m.Keys.ViewExport):
|
||||
if job := getSelectedJob(*m); job != nil {
|
||||
cmds = append(cmds, c.exportJob(job.Name))
|
||||
}
|
||||
case key.Matches(msg, m.Keys.FilterTeam):
|
||||
m.InputMode = true
|
||||
m.Input.SetValue("@")
|
||||
m.Input.Focus()
|
||||
m.Status = "Filter by team member: @alice, @bob, @team-ml"
|
||||
case key.Matches(msg, m.Keys.Cancel):
|
||||
if job := getSelectedJob(*m); job != nil && job.TaskID != "" {
|
||||
cmds = append(cmds, c.cancelTask(job.TaskID))
|
||||
|
|
@ -181,8 +210,18 @@ func (c *Controller) applyWindowSize(msg tea.WindowSizeMsg, m model.State) model
|
|||
m.QueueView.Height = listHeight - 4
|
||||
m.SettingsView.Width = panelWidth
|
||||
m.SettingsView.Height = listHeight - 4
|
||||
m.NarrativeView.Width = panelWidth
|
||||
m.NarrativeView.Height = listHeight - 4
|
||||
m.TeamView.Width = panelWidth
|
||||
m.TeamView.Height = listHeight - 4
|
||||
m.ExperimentsView.Width = panelWidth
|
||||
m.ExperimentsView.Height = listHeight - 4
|
||||
m.ExperimentHistoryView.Width = panelWidth
|
||||
m.ExperimentHistoryView.Height = listHeight - 4
|
||||
m.ConfigView.Width = panelWidth
|
||||
m.ConfigView.Height = listHeight - 4
|
||||
m.LogsView.Width = panelWidth
|
||||
m.LogsView.Height = listHeight - 4
|
||||
|
||||
return m
|
||||
}
|
||||
|
|
@ -245,7 +284,25 @@ func (c *Controller) handleStatusMsg(msg model.StatusMsg, m model.State) (model.
|
|||
|
||||
func (c *Controller) handleTickMsg(msg model.TickMsg, m model.State) (model.State, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
if time.Since(m.LastRefresh) > 10*time.Second && !m.IsLoading {
|
||||
|
||||
// Calculate actual refresh rate
|
||||
now := time.Now()
|
||||
if !m.LastFrameTime.IsZero() {
|
||||
elapsed := now.Sub(m.LastFrameTime).Milliseconds()
|
||||
if elapsed > 0 {
|
||||
// Smooth the rate with simple averaging
|
||||
m.RefreshRate = (m.RefreshRate*float64(m.FrameCount) + float64(elapsed)) / float64(m.FrameCount+1)
|
||||
m.FrameCount++
|
||||
if m.FrameCount > 100 {
|
||||
m.FrameCount = 1
|
||||
m.RefreshRate = float64(elapsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.LastFrameTime = now
|
||||
|
||||
// 500ms refresh target for real-time updates
|
||||
if time.Since(m.LastRefresh) > 500*time.Millisecond && !m.IsLoading {
|
||||
m.LastRefresh = time.Now()
|
||||
cmds = append(cmds, c.loadAllData())
|
||||
}
|
||||
|
|
@ -290,16 +347,25 @@ func New(
|
|||
tq *services.TaskQueue,
|
||||
logger *logging.Logger,
|
||||
) *Controller {
|
||||
// Create WebSocket client for real-time updates
|
||||
wsClient := services.NewWebSocketClient(cfg.ServerURL, "", logger)
|
||||
|
||||
return &Controller{
|
||||
config: cfg,
|
||||
server: srv,
|
||||
taskQueue: tq,
|
||||
logger: logger,
|
||||
wsClient: wsClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the TUI and returns initial commands
|
||||
func (c *Controller) Init() tea.Cmd {
|
||||
// Connect WebSocket for real-time updates
|
||||
if err := c.wsClient.Connect(); err != nil {
|
||||
c.logger.Error("WebSocket connection failed", "error", err)
|
||||
}
|
||||
|
||||
return tea.Batch(
|
||||
tea.SetWindowTitle("FetchML"),
|
||||
c.loadAllData(),
|
||||
|
|
@ -307,14 +373,17 @@ func (c *Controller) Init() tea.Cmd {
|
|||
)
|
||||
}
|
||||
|
||||
// Update handles all messages and updates the state
|
||||
func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
||||
switch typed := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
return c.handleKeyMsg(typed, m)
|
||||
case tea.WindowSizeMsg:
|
||||
updated := c.applyWindowSize(typed, m)
|
||||
return c.finalizeUpdate(msg, updated)
|
||||
// Only apply window size on first render, then keep constant
|
||||
if m.Width == 0 && m.Height == 0 {
|
||||
updated := c.applyWindowSize(typed, m)
|
||||
return c.finalizeUpdate(msg, updated)
|
||||
}
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.JobsLoadedMsg:
|
||||
return c.handleJobsLoadedMsg(typed, m)
|
||||
case model.TasksLoadedMsg:
|
||||
|
|
@ -325,6 +394,26 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
|||
return c.handleContainerContent(typed, m)
|
||||
case model.QueueLoadedMsg:
|
||||
return c.handleQueueContent(typed, m)
|
||||
case model.DatasetsLoadedMsg:
|
||||
// Format datasets into view content
|
||||
var content strings.Builder
|
||||
content.WriteString("Available Datasets\n")
|
||||
content.WriteString(strings.Repeat("═", 50) + "\n\n")
|
||||
if len(typed) == 0 {
|
||||
content.WriteString("📭 No datasets found\n\n")
|
||||
content.WriteString("Datasets will appear here when available\n")
|
||||
content.WriteString("in the data directory.")
|
||||
} else {
|
||||
for i, ds := range typed {
|
||||
content.WriteString(fmt.Sprintf("%d. 📁 %s\n", i+1, ds.Name))
|
||||
content.WriteString(fmt.Sprintf(" Location: %s\n", ds.Location))
|
||||
content.WriteString(fmt.Sprintf(" Size: %d bytes\n", ds.SizeBytes))
|
||||
content.WriteString(fmt.Sprintf(" Last Access: %s\n\n", ds.LastAccess.Format("2006-01-02 15:04")))
|
||||
}
|
||||
}
|
||||
m.DatasetView.SetContent(content.String())
|
||||
m.DatasetView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.SettingsContentMsg:
|
||||
m.SettingsView.SetContent(string(typed))
|
||||
return c.finalizeUpdate(msg, m)
|
||||
|
|
@ -332,12 +421,36 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
|||
m.ExperimentsView.SetContent(string(typed))
|
||||
m.ExperimentsView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case ExperimentHistoryLoadedMsg:
|
||||
m.ExperimentHistoryView.SetContent(string(typed))
|
||||
m.ExperimentHistoryView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case ConfigLoadedMsg:
|
||||
m.ConfigView.SetContent(string(typed))
|
||||
m.ConfigView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case LogsLoadedMsg:
|
||||
m.LogsView.SetContent(string(typed))
|
||||
m.LogsView.GotoTop()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.SettingsUpdateMsg:
|
||||
return c.finalizeUpdate(msg, m)
|
||||
case model.StatusMsg:
|
||||
return c.handleStatusMsg(typed, m)
|
||||
case model.TickMsg:
|
||||
return c.handleTickMsg(typed, m)
|
||||
case model.JobUpdateMsg:
|
||||
// Handle real-time job status updates from WebSocket
|
||||
m.Status = fmt.Sprintf("Job %s: %s", typed.JobName, typed.Status)
|
||||
// Refresh job list to show updated status
|
||||
return m, c.loadAllData()
|
||||
case model.GPUUpdateMsg:
|
||||
// Throttle GPU updates to 1/second (humans can't perceive faster)
|
||||
if time.Since(m.LastGPUUpdate) > 1*time.Second {
|
||||
m.LastGPUUpdate = time.Now()
|
||||
return c.finalizeUpdate(msg, m)
|
||||
}
|
||||
return m, nil
|
||||
default:
|
||||
return c.finalizeUpdate(msg, m)
|
||||
}
|
||||
|
|
@ -346,6 +459,12 @@ func (c *Controller) Update(msg tea.Msg, m model.State) (model.State, tea.Cmd) {
|
|||
// ExperimentsLoadedMsg is sent when experiments are loaded
|
||||
type ExperimentsLoadedMsg string
|
||||
|
||||
// ExperimentHistoryLoadedMsg is sent when experiment history is loaded
|
||||
type ExperimentHistoryLoadedMsg string
|
||||
|
||||
// ConfigLoadedMsg is sent when config is loaded
|
||||
type ConfigLoadedMsg string
|
||||
|
||||
func (c *Controller) loadExperiments() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
commitIDs, err := c.taskQueue.ListExperiments()
|
||||
|
|
@ -372,3 +491,92 @@ func (c *Controller) loadExperiments() tea.Cmd {
|
|||
return ExperimentsLoadedMsg(output)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadExperimentHistory() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Placeholder - will show experiment history with annotations
|
||||
return ExperimentHistoryLoadedMsg("Experiment History & Annotations\n\n" +
|
||||
"This view will show:\n" +
|
||||
"- Previous experiment runs\n" +
|
||||
"- Annotations and notes\n" +
|
||||
"- Config snapshots\n" +
|
||||
"- Side-by-side comparisons\n\n" +
|
||||
"(Requires API: GET /api/experiments/:id/history)")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) loadConfig() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Build config diff showing changes from defaults
|
||||
var output strings.Builder
|
||||
output.WriteString("⚙️ Config View (Read-Only)\n\n")
|
||||
|
||||
output.WriteString("┌─ Changes from Defaults ─────────────────────┐\n")
|
||||
changes := []string{}
|
||||
|
||||
if c.config.Host != "" {
|
||||
changes = append(changes, fmt.Sprintf("│ Host: %s", c.config.Host))
|
||||
}
|
||||
if c.config.Port != 0 && c.config.Port != 22 {
|
||||
changes = append(changes, fmt.Sprintf("│ Port: %d (default: 22)", c.config.Port))
|
||||
}
|
||||
if c.config.BasePath != "" {
|
||||
changes = append(changes, fmt.Sprintf("│ Base Path: %s", c.config.BasePath))
|
||||
}
|
||||
if c.config.RedisAddr != "" && c.config.RedisAddr != "localhost:6379" {
|
||||
changes = append(changes, fmt.Sprintf("│ Redis: %s (default: localhost:6379)", c.config.RedisAddr))
|
||||
}
|
||||
if c.config.ServerURL != "" {
|
||||
changes = append(changes, fmt.Sprintf("│ Server: %s", c.config.ServerURL))
|
||||
}
|
||||
|
||||
if len(changes) == 0 {
|
||||
output.WriteString("│ (Using all default settings)\n")
|
||||
} else {
|
||||
for _, change := range changes {
|
||||
output.WriteString(change + "\n")
|
||||
}
|
||||
}
|
||||
output.WriteString("└─────────────────────────────────────────────┘\n\n")
|
||||
|
||||
output.WriteString("Full Configuration:\n")
|
||||
output.WriteString(fmt.Sprintf(" Host: %s\n", c.config.Host))
|
||||
output.WriteString(fmt.Sprintf(" Port: %d\n", c.config.Port))
|
||||
output.WriteString(fmt.Sprintf(" Base Path: %s\n", c.config.BasePath))
|
||||
output.WriteString(fmt.Sprintf(" Redis: %s\n", c.config.RedisAddr))
|
||||
output.WriteString(fmt.Sprintf(" Server: %s\n", c.config.ServerURL))
|
||||
output.WriteString(fmt.Sprintf(" User: %s\n\n", c.config.User))
|
||||
|
||||
output.WriteString("Use CLI to modify: ml config set <key> <value>")
|
||||
|
||||
return ConfigLoadedMsg(output.String())
|
||||
}
|
||||
}
|
||||
|
||||
// LogsLoadedMsg is sent when logs are loaded
|
||||
type LogsLoadedMsg string
|
||||
|
||||
func (c *Controller) loadLogs(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Placeholder - will stream logs from job
|
||||
return LogsLoadedMsg("📜 Logs for " + jobName + "\n\n" +
|
||||
"Log streaming will appear here...\n\n" +
|
||||
"(Requires API: GET /api/jobs/" + jobName + "/logs?follow=true)")
|
||||
}
|
||||
}
|
||||
|
||||
// ExportCompletedMsg is sent when export is complete
|
||||
type ExportCompletedMsg struct {
|
||||
JobName string
|
||||
Path string
|
||||
}
|
||||
|
||||
func (c *Controller) exportJob(jobName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
// Show export in progress
|
||||
return model.StatusMsg{
|
||||
Text: "Exporting " + jobName + "... (anonymized)",
|
||||
Level: "info",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
// Package model provides TUI data structures and state management
|
||||
package model
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// JobStatus represents the status of a job
|
||||
type JobStatus string
|
||||
|
|
@ -21,12 +25,23 @@ type Job struct {
|
|||
Status JobStatus
|
||||
TaskID string
|
||||
Priority int64
|
||||
// Narrative fields for research context
|
||||
Hypothesis string
|
||||
Context string
|
||||
Intent string
|
||||
ExpectedOutcome string
|
||||
ActualOutcome string
|
||||
OutcomeStatus string // validated, invalidated, inconclusive
|
||||
// GPU allocation tracking
|
||||
GPUDeviceID int // -1 if not assigned
|
||||
GPUUtilization int // 0-100%
|
||||
GPUMemoryUsed int64 // MB
|
||||
}
|
||||
|
||||
// Title returns the job title for display
|
||||
func (j Job) Title() string { return j.Name }
|
||||
|
||||
// Description returns a formatted description with status icon
|
||||
// Description returns a formatted description with status icon and GPU info
|
||||
func (j Job) Description() string {
|
||||
icon := map[JobStatus]string{
|
||||
StatusPending: "⏸",
|
||||
|
|
@ -39,7 +54,16 @@ func (j Job) Description() string {
|
|||
if j.Priority > 0 {
|
||||
pri = fmt.Sprintf(" [P%d]", j.Priority)
|
||||
}
|
||||
return fmt.Sprintf("%s %s%s", icon, j.Status, pri)
|
||||
gpu := ""
|
||||
if j.GPUDeviceID >= 0 {
|
||||
gpu = fmt.Sprintf(" [GPU:%d %d%%]", j.GPUDeviceID, j.GPUUtilization)
|
||||
}
|
||||
|
||||
// Apply status color to the status text
|
||||
statusStyle := lipgloss.NewStyle().Foreground(StatusColor(j.Status))
|
||||
coloredStatus := statusStyle.Render(string(j.Status))
|
||||
|
||||
return fmt.Sprintf("%s %s%s%s", icon, coloredStatus, pri, gpu)
|
||||
}
|
||||
|
||||
// FilterValue returns the value used for filtering
|
||||
|
|
|
|||
|
|
@ -5,42 +5,56 @@ import "github.com/charmbracelet/bubbles/key"
|
|||
|
||||
// KeyMap defines key bindings for the TUI
|
||||
type KeyMap struct {
|
||||
Refresh key.Binding
|
||||
Trigger key.Binding
|
||||
TriggerArgs key.Binding
|
||||
ViewQueue key.Binding
|
||||
ViewContainer key.Binding
|
||||
ViewGPU key.Binding
|
||||
ViewJobs key.Binding
|
||||
ViewDatasets key.Binding
|
||||
ViewExperiments key.Binding
|
||||
ViewSettings key.Binding
|
||||
Cancel key.Binding
|
||||
Delete key.Binding
|
||||
MarkFailed key.Binding
|
||||
RefreshGPU key.Binding
|
||||
Help key.Binding
|
||||
Quit key.Binding
|
||||
Refresh key.Binding
|
||||
Trigger key.Binding
|
||||
TriggerArgs key.Binding
|
||||
ViewQueue key.Binding
|
||||
ViewContainer key.Binding
|
||||
ViewGPU key.Binding
|
||||
ViewJobs key.Binding
|
||||
ViewDatasets key.Binding
|
||||
ViewExperiments key.Binding
|
||||
ViewSettings key.Binding
|
||||
ViewNarrative key.Binding
|
||||
ViewTeam key.Binding
|
||||
ViewExperimentHistory key.Binding
|
||||
ViewConfig key.Binding
|
||||
ViewLogs key.Binding
|
||||
ViewExport key.Binding
|
||||
FilterTeam key.Binding
|
||||
Cancel key.Binding
|
||||
Delete key.Binding
|
||||
MarkFailed key.Binding
|
||||
RefreshGPU key.Binding
|
||||
Help key.Binding
|
||||
Quit key.Binding
|
||||
}
|
||||
|
||||
// DefaultKeys returns the default key bindings for the TUI
|
||||
func DefaultKeys() KeyMap {
|
||||
return KeyMap{
|
||||
Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
|
||||
Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
|
||||
TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
|
||||
ViewQueue: key.NewBinding(key.WithKeys("v"), key.WithHelp("v", "view queue")),
|
||||
ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
|
||||
ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
|
||||
ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
|
||||
ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
|
||||
ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
|
||||
Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel task")),
|
||||
Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
|
||||
MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
|
||||
RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
|
||||
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
|
||||
Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
|
||||
Quit: key.NewBinding(key.WithKeys("q", "ctrl+c"), key.WithHelp("q", "quit")),
|
||||
Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh all")),
|
||||
Trigger: key.NewBinding(key.WithKeys("t"), key.WithHelp("t", "queue job")),
|
||||
TriggerArgs: key.NewBinding(key.WithKeys("a"), key.WithHelp("a", "queue w/ args")),
|
||||
ViewQueue: key.NewBinding(key.WithKeys("q"), key.WithHelp("q", "view queue")),
|
||||
ViewContainer: key.NewBinding(key.WithKeys("o"), key.WithHelp("o", "containers")),
|
||||
ViewGPU: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "gpu status")),
|
||||
ViewJobs: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "job list")),
|
||||
ViewDatasets: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "datasets")),
|
||||
ViewExperiments: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "experiments")),
|
||||
ViewNarrative: key.NewBinding(key.WithKeys("n"), key.WithHelp("n", "narrative")),
|
||||
ViewTeam: key.NewBinding(key.WithKeys("m"), key.WithHelp("m", "team")),
|
||||
ViewExperimentHistory: key.NewBinding(key.WithKeys("e"), key.WithHelp("e", "experiment history")),
|
||||
ViewConfig: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "config")),
|
||||
ViewSettings: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "settings")),
|
||||
ViewLogs: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "logs")),
|
||||
ViewExport: key.NewBinding(key.WithKeys("E"), key.WithHelp("E", "export job")),
|
||||
FilterTeam: key.NewBinding(key.WithKeys("@"), key.WithHelp("@", "filter by team")),
|
||||
Cancel: key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "cancel task")),
|
||||
Delete: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "delete job")),
|
||||
MarkFailed: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "mark failed")),
|
||||
RefreshGPU: key.NewBinding(key.WithKeys("G"), key.WithHelp("G", "refresh GPU")),
|
||||
Help: key.NewBinding(key.WithKeys("h", "?"), key.WithHelp("h/?", "toggle help")),
|
||||
Quit: key.NewBinding(key.WithKeys("ctrl+c"), key.WithHelp("ctrl+c", "quit")),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ type JobsLoadedMsg []Job
|
|||
// TasksLoadedMsg contains loaded tasks from the queue
|
||||
type TasksLoadedMsg []*Task
|
||||
|
||||
// DatasetsLoadedMsg contains loaded datasets
|
||||
type DatasetsLoadedMsg []DatasetInfo
|
||||
|
||||
// GpuLoadedMsg contains GPU status information
|
||||
type GpuLoadedMsg string
|
||||
|
||||
|
|
|
|||
|
|
@ -32,13 +32,18 @@ type ViewMode int
|
|||
|
||||
// ViewMode constants represent different TUI views
|
||||
const (
|
||||
ViewModeJobs ViewMode = iota // Jobs view mode
|
||||
ViewModeGPU // GPU status view mode
|
||||
ViewModeQueue // Queue status view mode
|
||||
ViewModeContainer // Container status view mode
|
||||
ViewModeSettings // Settings view mode
|
||||
ViewModeDatasets // Datasets view mode
|
||||
ViewModeExperiments // Experiments view mode
|
||||
ViewModeJobs ViewMode = iota // Jobs view mode
|
||||
ViewModeGPU // GPU status view mode
|
||||
ViewModeQueue // Queue status view mode
|
||||
ViewModeContainer // Container status view mode
|
||||
ViewModeSettings // Settings view mode
|
||||
ViewModeDatasets // Datasets view mode
|
||||
ViewModeExperiments // Experiments view mode
|
||||
ViewModeNarrative // Narrative/Outcome view mode
|
||||
ViewModeTeam // Team collaboration view mode
|
||||
ViewModeExperimentHistory // Experiment history view mode
|
||||
ViewModeConfig // Config view mode
|
||||
ViewModeLogs // Logs streaming view mode
|
||||
)
|
||||
|
||||
// DatasetInfo represents dataset information in the TUI
|
||||
|
|
@ -51,32 +56,42 @@ type DatasetInfo struct {
|
|||
|
||||
// State holds the application state
|
||||
type State struct {
|
||||
Jobs []Job
|
||||
QueuedTasks []*Task
|
||||
Datasets []DatasetInfo
|
||||
JobList list.Model
|
||||
GpuView viewport.Model
|
||||
ContainerView viewport.Model
|
||||
QueueView viewport.Model
|
||||
SettingsView viewport.Model
|
||||
DatasetView viewport.Model
|
||||
ExperimentsView viewport.Model
|
||||
Input textinput.Model
|
||||
APIKeyInput textinput.Model
|
||||
Status string
|
||||
ErrorMsg string
|
||||
InputMode bool
|
||||
Width int
|
||||
Height int
|
||||
ShowHelp bool
|
||||
Spinner spinner.Model
|
||||
ActiveView ViewMode
|
||||
LastRefresh time.Time
|
||||
IsLoading bool
|
||||
JobStats map[JobStatus]int
|
||||
APIKey string
|
||||
SettingsIndex int
|
||||
Keys KeyMap
|
||||
Jobs []Job
|
||||
QueuedTasks []*Task
|
||||
Datasets []DatasetInfo
|
||||
JobList list.Model
|
||||
GpuView viewport.Model
|
||||
ContainerView viewport.Model
|
||||
QueueView viewport.Model
|
||||
SettingsView viewport.Model
|
||||
DatasetView viewport.Model
|
||||
ExperimentsView viewport.Model
|
||||
NarrativeView viewport.Model
|
||||
TeamView viewport.Model
|
||||
ExperimentHistoryView viewport.Model
|
||||
ConfigView viewport.Model
|
||||
LogsView viewport.Model
|
||||
SelectedJob Job
|
||||
Input textinput.Model
|
||||
APIKeyInput textinput.Model
|
||||
Status string
|
||||
ErrorMsg string
|
||||
InputMode bool
|
||||
Width int
|
||||
Height int
|
||||
ShowHelp bool
|
||||
Spinner spinner.Model
|
||||
ActiveView ViewMode
|
||||
LastRefresh time.Time
|
||||
LastFrameTime time.Time
|
||||
RefreshRate float64 // measured in ms
|
||||
FrameCount int
|
||||
LastGPUUpdate time.Time
|
||||
IsLoading bool
|
||||
JobStats map[JobStatus]int
|
||||
APIKey string
|
||||
SettingsIndex int
|
||||
Keys KeyMap
|
||||
}
|
||||
|
||||
// InitialState creates the initial application state
|
||||
|
|
@ -105,25 +120,54 @@ func InitialState(apiKey string) State {
|
|||
s.Style = SpinnerStyle()
|
||||
|
||||
return State{
|
||||
JobList: jobList,
|
||||
GpuView: viewport.New(0, 0),
|
||||
ContainerView: viewport.New(0, 0),
|
||||
QueueView: viewport.New(0, 0),
|
||||
SettingsView: viewport.New(0, 0),
|
||||
DatasetView: viewport.New(0, 0),
|
||||
ExperimentsView: viewport.New(0, 0),
|
||||
Input: input,
|
||||
APIKeyInput: apiKeyInput,
|
||||
Status: "Connected",
|
||||
InputMode: false,
|
||||
ShowHelp: false,
|
||||
Spinner: s,
|
||||
ActiveView: ViewModeJobs,
|
||||
LastRefresh: time.Now(),
|
||||
IsLoading: false,
|
||||
JobStats: make(map[JobStatus]int),
|
||||
APIKey: apiKey,
|
||||
SettingsIndex: 0,
|
||||
Keys: DefaultKeys(),
|
||||
JobList: jobList,
|
||||
GpuView: viewport.New(0, 0),
|
||||
ContainerView: viewport.New(0, 0),
|
||||
QueueView: viewport.New(0, 0),
|
||||
SettingsView: viewport.New(0, 0),
|
||||
DatasetView: viewport.New(0, 0),
|
||||
ExperimentsView: viewport.New(0, 0),
|
||||
NarrativeView: viewport.New(0, 0),
|
||||
TeamView: viewport.New(0, 0),
|
||||
ExperimentHistoryView: viewport.New(0, 0),
|
||||
ConfigView: viewport.New(0, 0),
|
||||
LogsView: viewport.New(0, 0),
|
||||
Input: input,
|
||||
APIKeyInput: apiKeyInput,
|
||||
Status: "Connected",
|
||||
InputMode: false,
|
||||
ShowHelp: false,
|
||||
Spinner: s,
|
||||
ActiveView: ViewModeJobs,
|
||||
LastRefresh: time.Now(),
|
||||
IsLoading: false,
|
||||
JobStats: make(map[JobStatus]int),
|
||||
APIKey: apiKey,
|
||||
SettingsIndex: 0,
|
||||
Keys: DefaultKeys(),
|
||||
}
|
||||
}
|
||||
|
||||
// LogMsg represents a log line from a job
|
||||
type LogMsg struct {
|
||||
JobName string `json:"job_name"`
|
||||
Line string `json:"line"`
|
||||
Time string `json:"time"`
|
||||
}
|
||||
|
||||
// JobUpdateMsg represents a real-time job status update via WebSocket
|
||||
type JobUpdateMsg struct {
|
||||
JobName string `json:"job_name"`
|
||||
Status string `json:"status"`
|
||||
TaskID string `json:"task_id"`
|
||||
Progress int `json:"progress"`
|
||||
}
|
||||
|
||||
// GPUUpdateMsg represents a real-time GPU status update via WebSocket
|
||||
type GPUUpdateMsg struct {
|
||||
DeviceID int `json:"device_id"`
|
||||
Utilization int `json:"utilization"`
|
||||
MemoryUsed int64 `json:"memory_used"`
|
||||
MemoryTotal int64 `json:"memory_total"`
|
||||
Temperature int `json:"temperature"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,38 @@ import (
|
|||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Status colors for job list items
|
||||
var (
|
||||
// StatusRunningColor is green for running jobs
|
||||
StatusRunningColor = lipgloss.Color("#2ecc71")
|
||||
// StatusPendingColor is yellow for pending jobs
|
||||
StatusPendingColor = lipgloss.Color("#f1c40f")
|
||||
// StatusFailedColor is red for failed jobs
|
||||
StatusFailedColor = lipgloss.Color("#e74c3c")
|
||||
// StatusFinishedColor is blue for completed jobs
|
||||
StatusFinishedColor = lipgloss.Color("#3498db")
|
||||
// StatusQueuedColor is gray for queued jobs
|
||||
StatusQueuedColor = lipgloss.Color("#95a5a6")
|
||||
)
|
||||
|
||||
// StatusColor returns the color for a job status
|
||||
func StatusColor(status JobStatus) lipgloss.Color {
|
||||
switch status {
|
||||
case StatusRunning:
|
||||
return StatusRunningColor
|
||||
case StatusPending:
|
||||
return StatusPendingColor
|
||||
case StatusFailed:
|
||||
return StatusFailedColor
|
||||
case StatusFinished:
|
||||
return StatusFinishedColor
|
||||
case StatusQueued:
|
||||
return StatusQueuedColor
|
||||
default:
|
||||
return lipgloss.Color("#ffffff")
|
||||
}
|
||||
}
|
||||
|
||||
// NewJobListDelegate creates a styled delegate for the job list
|
||||
func NewJobListDelegate() list.DefaultDelegate {
|
||||
delegate := list.NewDefaultDelegate()
|
||||
|
|
|
|||
46
cmd/tui/internal/services/export.go
Normal file
46
cmd/tui/internal/services/export.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Package services provides TUI service clients
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/logging"
|
||||
)
|
||||
|
||||
// ExportService handles job export functionality for TUI
|
||||
type ExportService struct {
|
||||
serverURL string
|
||||
apiKey string
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewExportService creates a new export service
|
||||
func NewExportService(serverURL, apiKey string, logger *logging.Logger) *ExportService {
|
||||
return &ExportService{
|
||||
serverURL: serverURL,
|
||||
apiKey: apiKey,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExportJob exports a job with optional anonymization
|
||||
// Returns the path to the exported file
|
||||
func (s *ExportService) ExportJob(jobName string, anonymize bool) (string, error) {
|
||||
s.logger.Info("exporting job", "job", jobName, "anonymize", anonymize)
|
||||
|
||||
// Placeholder - actual implementation would call API
|
||||
// POST /api/jobs/{id}/export?anonymize=true
|
||||
|
||||
exportPath := fmt.Sprintf("/tmp/%s_export_%d.tar.gz", jobName, time.Now().Unix())
|
||||
|
||||
s.logger.Info("export complete", "job", jobName, "path", exportPath)
|
||||
return exportPath, nil
|
||||
}
|
||||
|
||||
// ExportOptions contains options for export
|
||||
type ExportOptions struct {
|
||||
Anonymize bool
|
||||
IncludeLogs bool
|
||||
IncludeData bool
|
||||
}
|
||||
|
|
@ -4,8 +4,12 @@ package services
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/config"
|
||||
"github.com/jfraeys/fetch_ml/cmd/tui/internal/model"
|
||||
"github.com/jfraeys/fetch_ml/internal/domain"
|
||||
"github.com/jfraeys/fetch_ml/internal/experiment"
|
||||
"github.com/jfraeys/fetch_ml/internal/network"
|
||||
|
|
@ -21,6 +25,7 @@ type TaskQueue struct {
|
|||
*queue.TaskQueue // Embed to inherit all queue methods directly
|
||||
expManager *experiment.Manager
|
||||
ctx context.Context
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewTaskQueue creates a new task queue service
|
||||
|
|
@ -37,14 +42,17 @@ func NewTaskQueue(cfg *config.Config) (*TaskQueue, error) {
|
|||
return nil, fmt.Errorf("failed to create task queue: %w", err)
|
||||
}
|
||||
|
||||
// Initialize experiment manager
|
||||
// TODO: Get base path from config
|
||||
expManager := experiment.NewManager("./experiments")
|
||||
// Initialize experiment manager with proper path
|
||||
// BasePath already includes the mode-based experiments path (e.g., ./data/dev/experiments)
|
||||
expDir := cfg.BasePath
|
||||
os.MkdirAll(expDir, 0755)
|
||||
expManager := experiment.NewManager(expDir)
|
||||
|
||||
return &TaskQueue{
|
||||
TaskQueue: internalQueue,
|
||||
expManager: expManager,
|
||||
ctx: context.Background(),
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -94,20 +102,38 @@ func (tq *TaskQueue) GetMetrics(_ string) (map[string]string, error) {
|
|||
return map[string]string{}, nil
|
||||
}
|
||||
|
||||
// ListDatasets retrieves available datasets (TUI-specific: currently returns empty)
|
||||
func (tq *TaskQueue) ListDatasets() ([]struct {
|
||||
Name string
|
||||
SizeBytes int64
|
||||
Location string
|
||||
LastAccess string
|
||||
}, error) {
|
||||
// This method doesn't exist in internal queue, return empty for now
|
||||
return []struct {
|
||||
Name string
|
||||
SizeBytes int64
|
||||
Location string
|
||||
LastAccess string
|
||||
}{}, nil
|
||||
// ListDatasets retrieves available datasets from the filesystem
|
||||
func (tq *TaskQueue) ListDatasets() ([]model.DatasetInfo, error) {
|
||||
var datasets []model.DatasetInfo
|
||||
|
||||
// Scan the active data directory for datasets
|
||||
dataDir := tq.config.BasePath
|
||||
if dataDir == "" {
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dataDir)
|
||||
if err != nil {
|
||||
// Directory might not exist yet, return empty
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
datasets = append(datasets, model.DatasetInfo{
|
||||
Name: entry.Name(),
|
||||
SizeBytes: info.Size(),
|
||||
Location: filepath.Join(dataDir, entry.Name()),
|
||||
LastAccess: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return datasets, nil
|
||||
}
|
||||
|
||||
// ListExperiments retrieves experiment list
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue