diff --git a/cli/src/commands/privacy.zig b/cli/src/commands/privacy.zig new file mode 100644 index 0000000..e25e8bb --- /dev/null +++ b/cli/src/commands/privacy.zig @@ -0,0 +1,241 @@ +const std = @import("std"); +const colors = @import("../utils/colors.zig"); +const Config = @import("../config.zig").Config; +const crypto = @import("../utils/crypto.zig"); +const io = @import("../utils/io.zig"); +const ws = @import("../net/ws/client.zig"); +const protocol = @import("../net/protocol.zig"); +const manifest = @import("../utils/manifest.zig"); + +pub fn run(allocator: std.mem.Allocator, argv: []const []const u8) !void { + if (argv.len == 0) { + try printUsage(); + return error.InvalidArgs; + } + + const sub = argv[0]; + if (std.mem.eql(u8, sub, "--help") or std.mem.eql(u8, sub, "-h")) { + try printUsage(); + return; + } + + if (!std.mem.eql(u8, sub, "set")) { + colors.printError("Unknown subcommand: {s}\n", .{sub}); + try printUsage(); + return error.InvalidArgs; + } + + if (argv.len < 2) { + try printUsage(); + return error.InvalidArgs; + } + + const target = argv[1]; + + var privacy_level: ?[]const u8 = null; + var team: ?[]const u8 = null; + var owner: ?[]const u8 = null; + var base_override: ?[]const u8 = null; + var json_mode: bool = false; + + var i: usize = 2; + while (i < argv.len) : (i += 1) { + const a = argv[i]; + if (std.mem.eql(u8, a, "--level") or std.mem.eql(u8, a, "--privacy-level")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + privacy_level = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--team")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + team = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--owner")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + owner = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--base")) { + if (i + 1 >= argv.len) return error.InvalidArgs; + base_override = argv[i + 1]; + i += 1; + } else if (std.mem.eql(u8, a, "--json")) { + json_mode = true; + } else if (std.mem.eql(u8, a, "--help") or std.mem.eql(u8, a, "-h")) { + try printUsage(); + return; + } else { + colors.printError("Unknown option: {s}\n", .{a}); + return error.InvalidArgs; + } + } + + if (privacy_level == null and team == null and owner == null) { + colors.printError("No privacy fields provided.\n", .{}); + return error.InvalidArgs; + } + + // Validate privacy level if provided + if (privacy_level) |pl| { + const valid = std.mem.eql(u8, pl, "private") or + std.mem.eql(u8, pl, "team") or + std.mem.eql(u8, pl, "public") or + std.mem.eql(u8, pl, "anonymized"); + if (!valid) { + colors.printError("Invalid privacy level: {s}. Must be one of: private, team, public, anonymized\n", .{pl}); + return error.InvalidArgs; + } + } + + const cfg = try Config.load(allocator); + defer { + var mut_cfg = cfg; + mut_cfg.deinit(allocator); + } + + const resolved_base = base_override orelse cfg.worker_base; + const manifest_path = manifest.resolvePathWithBase(allocator, target, resolved_base) catch |err| { + if (err == error.FileNotFound) { + colors.printError( + "Could not locate run_manifest.json for '{s}'. Provide a path, or use --base to scan finished/failed/running/pending.\n", + .{target}, + ); + } + return err; + }; + defer allocator.free(manifest_path); + + const job_name = try manifest.readJobNameFromManifest(allocator, manifest_path); + defer allocator.free(job_name); + + const patch_json = try buildPrivacyPatchJSON( + allocator, + privacy_level, + team, + owner, + ); + defer allocator.free(patch_json); + + const api_key_hash = try crypto.hashApiKey(allocator, cfg.api_key); + defer allocator.free(api_key_hash); + + const ws_url = try cfg.getWebSocketUrl(allocator); + defer allocator.free(ws_url); + + var client = try ws.Client.connect(allocator, ws_url, cfg.api_key); + defer client.close(); + + try client.sendSetRunPrivacy(job_name, patch_json, api_key_hash); + + if (json_mode) { + const msg = try client.receiveMessage(allocator); + defer allocator.free(msg); + + const packet = protocol.ResponsePacket.deserialize(msg, allocator) catch { + var out = io.stdoutWriter(); + try out.print("{s}\n", .{msg}); + return error.InvalidPacket; + }; + defer packet.deinit(allocator); + + if (packet.packet_type == .success) { + var out = io.stdoutWriter(); + try out.print("{{\"success\":true,\"job_name\":\"{s}\"}}\n", .{job_name}); + } else if (packet.packet_type == .error_packet) { + var out = io.stdoutWriter(); + try out.print("{{\"success\":false,\"error\":\"{s}\"}}\n", .{packet.error_message orelse "unknown"}); + } + } else { + try client.receiveAndHandleResponse(allocator, "Privacy set"); + } +} + +fn printUsage() !void { + colors.printInfo("Usage: ml privacy set [options]\n", .{}); + colors.printInfo("\nPrivacy Levels:\n", .{}); + colors.printInfo(" private Owner only (default)\n", .{}); + colors.printInfo(" team Same-team members can view\n", .{}); + colors.printInfo(" public All authenticated users\n", .{}); + colors.printInfo(" anonymized Strip PII before sharing\n", .{}); + colors.printInfo("\nOptions:\n", .{}); + colors.printInfo(" --level Set privacy level\n", .{}); + colors.printInfo(" --team Set team name\n", .{}); + colors.printInfo(" --owner Set owner email\n", .{}); + colors.printInfo(" --base Base path to search for run_manifest.json\n", .{}); + colors.printInfo(" --json Output JSON response\n", .{}); + colors.printInfo(" --help, -h Show this help\n", .{}); + colors.printInfo("\nExamples:\n", .{}); + colors.printInfo(" ml privacy set run_abc --level team --team vision-research\n", .{}); + colors.printInfo(" ml privacy set run_abc --level private\n", .{}); + colors.printInfo(" ml privacy set run_abc --owner user@lab.edu --team ml-group\n", .{}); +} + +fn buildPrivacyPatchJSON( + allocator: std.mem.Allocator, + privacy_level: ?[]const u8, + team: ?[]const u8, + owner: ?[]const u8, +) ![]u8 { + var buf = std.ArrayList(u8).empty; + defer buf.deinit(allocator); + + const writer = buf.writer(allocator); + try writer.writeAll("{\"privacy\":"); + try writer.writeAll("{"); + + var first = true; + + if (privacy_level) |pl| { + if (!first) try writer.writeAll(","); + first = false; + try writer.print("\"level\":\"{s}\"", .{pl}); + } + + if (team) |t| { + if (!first) try writer.writeAll(","); + first = false; + try writer.writeAll("\"team\":"); + try writeJSONString(writer, t); + } + + if (owner) |o| { + if (!first) try writer.writeAll(","); + first = false; + try writer.writeAll("\"owner\":"); + try writeJSONString(writer, o); + } + + try writer.writeAll("}}"); + + return buf.toOwnedSlice(allocator); +} + +fn writeJSONString(writer: anytype, s: []const u8) !void { + try writer.writeAll("\""); + for (s) |c| { + switch (c) { + '"' => try writer.writeAll("\\\""), + '\\' => try writer.writeAll("\\\\"), + '\n' => try writer.writeAll("\\n"), + '\r' => try writer.writeAll("\\r"), + '\t' => try writer.writeAll("\\t"), + else => { + if (c < 0x20) { + var buf: [6]u8 = undefined; + buf[0] = '\\'; + buf[1] = 'u'; + buf[2] = '0'; + buf[3] = '0'; + buf[4] = hexDigit(@intCast((c >> 4) & 0x0F)); + buf[5] = hexDigit(@intCast(c & 0x0F)); + try writer.writeAll(&buf); + } else { + try writer.writeAll(&[_]u8{c}); + } + }, + } + } + try writer.writeAll("\""); +} + +fn hexDigit(v: u8) u8 { + return if (v < 10) ('0' + v) else ('a' + (v - 10)); +} diff --git a/cli/src/utils/pii.zig b/cli/src/utils/pii.zig new file mode 100644 index 0000000..8a88a51 --- /dev/null +++ b/cli/src/utils/pii.zig @@ -0,0 +1,278 @@ +const std = @import("std"); + +/// PII detection patterns for research data privacy +pub const PIIPatterns = struct { + email: std.regex.Regex, + ssn: std.regex.Regex, + phone: std.regex.Regex, + credit_card: std.regex.Regex, + ip_address: std.regex.Regex, + + pub fn init(allocator: std.mem.Allocator) !PIIPatterns { + return PIIPatterns{ + .email = try std.regex.Regex.compile(allocator, "\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b"), + .ssn = try std.regex.Regex.compile(allocator, "\\b\\d{3}-\\d{2}-\\d{4}\\b"), + .phone = try std.regex.Regex.compile(allocator, "\\b\\d{3}-\\d{3}-\\d{4}\\b"), + .credit_card = try std.regex.Regex.compile(allocator, "\\b(?:\\d[ -]*?){13,16}\\b"), + .ip_address = try std.regex.Regex.compile(allocator, "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b"), + }; + } + + pub fn deinit(self: *PIIPatterns) void { + self.email.deinit(); + self.ssn.deinit(); + self.phone.deinit(); + self.credit_card.deinit(); + self.ip_address.deinit(); + } +}; + +/// A single PII finding +pub const PIIFinding = struct { + pii_type: []const u8, + start_pos: usize, + end_pos: usize, + matched_text: []const u8, + + pub fn format(self: PIIFinding, allocator: std.mem.Allocator) ![]u8 { + return std.fmt.allocPrint(allocator, "{s} at position {d}: '{s}'", .{ + self.pii_type, self.start_pos, self.matched_text, + }); + } +}; + +/// Detect PII in text - simplified version without regex +pub fn detectPIISimple(text: []const u8, allocator: std.mem.Allocator) ![]PIIFinding { + var findings = std.ArrayList(PIIFinding).initCapacity(allocator, 10) catch |err| { + return err; + }; + defer findings.deinit(allocator); + + // Check for email patterns (@ symbol with surrounding text) + var i: usize = 0; + while (i < text.len) : (i += 1) { + if (text[i] == '@') { + // Look backwards for email start + var start = i; + while (start > 0 and isEmailChar(text[start - 1])) { + start -= 1; + } + + // Look forwards for email end + var end = i + 1; + while (end < text.len and isEmailChar(text[end])) { + end += 1; + } + + // Check if it looks like an email (has . after @) + if (i > start and end > i + 1) { + var has_dot = false; + for (text[i + 1 .. end]) |c| { + if (c == '.') { + has_dot = true; + break; + } + } + if (has_dot) { + try findings.append(allocator, PIIFinding{ + .pii_type = "email", + .start_pos = start, + .end_pos = end, + .matched_text = text[start..end], + }); + } + } + } + } + + // Check for IP addresses (simple pattern: XXX.XXX.XXX.XXX) + i = 0; + while (i < text.len) : (i += 1) { + if (std.ascii.isDigit(text[i])) { + var end = i; + var dot_count: u8 = 0; + var digit_count: u8 = 0; + + while (end < text.len and (std.ascii.isDigit(text[end]) or text[end] == '.')) { + if (text[end] == '.') { + dot_count += 1; + digit_count = 0; + } else { + digit_count += 1; + if (digit_count > 3) break; + } + if (dot_count > 3) break; + end += 1; + } + + // Check if it looks like an IP address (xxx.xxx.xxx.xxx pattern) + if (dot_count == 3 and end - i >= 7 and end - i <= 15) { + var valid = true; + var num_start = i; + var num_idx: u8 = 0; + var nums: [4]u32 = undefined; + + var idx: usize = 0; + while (idx < end - i) : (idx += 1) { + const c = text[i + idx]; + if (c == '.') { + const num_str = text[num_start .. i + idx]; + nums[num_idx] = std.fmt.parseInt(u32, num_str, 10) catch { + valid = false; + break; + }; + if (nums[num_idx] > 255) { + valid = false; + break; + } + num_idx += 1; + num_start = i + idx + 1; + } + } + + // Parse last number + if (valid and num_idx == 3) { + const num_str = text[num_start..end]; + if (std.fmt.parseInt(u32, num_str, 10)) |parsed_num| { + nums[num_idx] = parsed_num; + if (valid and nums[num_idx] <= 255) { + try findings.append(allocator, PIIFinding{ + .pii_type = "ip_address", + .start_pos = i, + .end_pos = end, + .matched_text = text[i..end], + }); + } + } else |_| { + valid = false; + } + } + } + } + } + + return findings.toOwnedSlice(allocator); +} + +fn isEmailChar(c: u8) bool { + return std.ascii.isAlphanumeric(c) or c == '.' or c == '_' or c == '%' or c == '+' or c == '-'; +} + +/// Scan text and return warning if PII detected +pub fn scanForPII(text: []const u8, allocator: std.mem.Allocator) !?[]const u8 { + const findings = try detectPIISimple(text, allocator); + defer allocator.free(findings); + + if (findings.len == 0) { + return null; + } + + var warning = std.ArrayList(u8).initCapacity(allocator, 256) catch |err| { + return err; + }; + defer warning.deinit(allocator); + + const writer = warning.writer(allocator); + try writer.writeAll("Warning: Potential PII detected:\n"); + + for (findings) |finding| { + try writer.print(" - {s}: '{s}'\n", .{ finding.pii_type, finding.matched_text }); + } + + try writer.writeAll("Use --force to store anyway, or edit your text."); + + return try warning.toOwnedSlice(allocator); +} + +/// Redact PII from text for anonymized export +pub fn redactPII(text: []const u8, allocator: std.mem.Allocator) ![]u8 { + const findings = try detectPIISimple(text, allocator); + defer allocator.free(findings); + + if (findings.len == 0) { + return allocator.dupe(u8, text); + } + + // Sort findings by position + std.sort.sort(PIIFinding, findings, {}, compareByStartPos); + + var result = std.ArrayList(u8).initCapacity(allocator, text.len) catch |err| { + return err; + }; + defer result.deinit(allocator); + + var last_end: usize = 0; + var redaction_counter: u32 = 0; + + for (findings) |finding| { + // Append text before this finding + if (finding.start_pos > last_end) { + try result.appendSlice(text[last_end..finding.start_pos]); + } + + // Append redaction placeholder + redaction_counter += 1; + if (std.mem.eql(u8, finding.pii_type, "email")) { + try result.writer(allocator).print("[EMAIL-{d}]", .{redaction_counter}); + } else if (std.mem.eql(u8, finding.pii_type, "ip_address")) { + try result.writer(allocator).print("[IP-{d}]", .{redaction_counter}); + } else { + try result.writer(allocator).print("[REDACTED-{d}]", .{redaction_counter}); + } + + last_end = finding.end_pos; + } + + // Append remaining text + if (last_end < text.len) { + try result.appendSlice(text[last_end..]); + } + + return result.toOwnedSlice(allocator); +} + +fn compareByStartPos(_: void, a: PIIFinding, b: PIIFinding) bool { + return a.start_pos < b.start_pos; +} + +/// Format findings as JSON for API responses +pub fn formatFindingsAsJson(findings: []const PIIFinding, allocator: std.mem.Allocator) ![]u8 { + var buf = std.ArrayList(u8).initCapacity(allocator, 1024) catch |err| { + return err; + }; + defer buf.deinit(allocator); + + const writer = buf.writer(allocator); + try writer.writeAll("["); + + for (findings, 0..) |finding, idx| { + if (idx > 0) try writer.writeAll(","); + try writer.writeAll("{"); + try writer.print("\"type\":\"{s}\",", .{finding.pii_type}); + try writer.print("\"start\":{d},", .{finding.start_pos}); + try writer.print("\"end\":{d},", .{finding.end_pos}); + try writer.writeAll("\"matched\":\""); + // Escape the matched text + for (finding.matched_text) |c| { + switch (c) { + '"' => try writer.writeAll("\\\""), + '\\' => try writer.writeAll("\\\\"), + '\n' => try writer.writeAll("\\n"), + '\r' => try writer.writeAll("\\r"), + '\t' => try writer.writeAll("\\t"), + else => { + if (c < 0x20) { + try writer.print("\\u00{x:0>2}", .{c}); + } else { + try writer.writeByte(c); + } + }, + } + } + try writer.writeAll("\""); + try writer.writeAll("}"); + } + + try writer.writeAll("]"); + return buf.toOwnedSlice(allocator); +} diff --git a/cli/src/utils/suggest.zig b/cli/src/utils/suggest.zig new file mode 100644 index 0000000..f85e167 --- /dev/null +++ b/cli/src/utils/suggest.zig @@ -0,0 +1,264 @@ +const std = @import("std"); + +/// Calculate Levenshtein distance between two strings +pub fn levenshteinDistance(allocator: std.mem.Allocator, s1: []const u8, s2: []const u8) !usize { + const m = s1.len + 1; + const n = s2.len + 1; + + // Create a 2D array for dynamic programming + var dp = try allocator.alloc(usize, m * n); + defer allocator.free(dp); + + // Initialize first row and column + for (0..m) |i| { + dp[i * n] = i; + } + for (0..n) |j| { + dp[j] = j; + } + + // Fill the matrix + for (1..m) |i| { + for (1..n) |j| { + const cost: usize = if (s1[i - 1] == s2[j - 1]) 0 else 1; + const deletion = dp[(i - 1) * n + j] + 1; + const insertion = dp[i * n + (j - 1)] + 1; + const substitution = dp[(i - 1) * n + (j - 1)] + cost; + dp[i * n + j] = @min(@min(deletion, insertion), substitution); + } + } + + return dp[(m - 1) * n + (n - 1)]; +} + +/// Find suggestions for a typo from a list of candidates +pub fn findSuggestions( + allocator: std.mem.Allocator, + input: []const u8, + candidates: []const []const u8, + max_distance: usize, + max_suggestions: usize, +) ![][]const u8 { + var suggestions = std.ArrayList([]const u8).empty; + defer suggestions.deinit(allocator); + + var distances = std.ArrayList(usize).empty; + defer distances.deinit(allocator); + + for (candidates) |candidate| { + const dist = try levenshteinDistance(allocator, input, candidate); + if (dist <= max_distance) { + try suggestions.append(allocator, candidate); + try distances.append(allocator, dist); + } + } + + // Sort by distance (bubble sort for simplicity with small lists) + const n = distances.items.len; + for (0..n) |i| { + for (0..n - i - 1) |j| { + if (distances.items[j] > distances.items[j + 1]) { + // Swap distances + const temp_dist = distances.items[j]; + distances.items[j] = distances.items[j + 1]; + distances.items[j + 1] = temp_dist; + // Swap corresponding suggestions + const temp_sugg = suggestions.items[j]; + suggestions.items[j] = suggestions.items[j + 1]; + suggestions.items[j + 1] = temp_sugg; + } + } + } + + // Return top suggestions + const count = @min(suggestions.items.len, max_suggestions); + const result = try allocator.alloc([]const u8, count); + for (0..count) |i| { + result[i] = try allocator.dupe(u8, suggestions.items[i]); + } + + return result; +} + +/// Suggest commands based on prefix matching +pub fn suggestCommands(input: []const u8) ?[]const []const u8 { + const all_commands = [_][]const u8{ + "init", "sync", "queue", "requeue", "status", + "monitor", "cancel", "prune", "watch", "dataset", + "experiment", "narrative", "outcome", "info", "logs", + "annotate", "validate", "compare", "find", "export", + }; + + // Exact match - no suggestion needed + for (all_commands) |cmd| { + if (std.mem.eql(u8, input, cmd)) return null; + } + + // Find prefix matches + var matches: [5][]const u8 = undefined; + var match_count: usize = 0; + + for (all_commands) |cmd| { + if (std.mem.startsWith(u8, cmd, input)) { + matches[match_count] = cmd; + match_count += 1; + if (match_count >= 5) break; + } + } + + if (match_count == 0) return null; + + // Return static slice - caller must not free + return matches[0..match_count]; +} + +/// Suggest flags for a command +pub fn suggestFlags(command: []const u8, input: []const u8) ?[]const []const u8 { + // Common flags for all commands + const common_flags = [_][]const u8{ "--help", "--verbose", "--quiet", "--json" }; + + // Command-specific flags + const queue_flags = [_][]const u8{ + "--commit", "--priority", "--cpu", "--memory", "--gpu", + "--gpu-memory", "--hypothesis", "--context", "--intent", "--expected-outcome", + "--experiment-group", "--tags", "--dry-run", "--validate", "--explain", + "--force", + }; + + const find_flags = [_][]const u8{ + "--tag", "--outcome", "--dataset", "--experiment-group", + "--author", "--after", "--before", "--limit", + }; + + const compare_flags = [_][]const u8{ + "--json", "--all", "--fields", + }; + + const export_flags = [_][]const u8{ + "--bundle", "--anonymize", "--anonymize-level", "--base", + }; + + // Select flags based on command + const flags: []const []const u8 = switch (std.meta.stringToEnum(Command, command) orelse .unknown) { + .queue => &queue_flags, + .find => &find_flags, + .compare => &compare_flags, + .export_cmd => &export_flags, + else => &common_flags, + }; + + // Find prefix matches + var matches: [5][]const u8 = undefined; + var match_count: usize = 0; + + // Check common flags first + for (common_flags) |flag| { + if (std.mem.startsWith(u8, flag, input)) { + matches[match_count] = flag; + match_count += 1; + if (match_count >= 5) break; + } + } + + // Then check command-specific flags + if (match_count < 5) { + for (flags) |flag| { + if (std.mem.startsWith(u8, flag, input)) { + // Avoid duplicates + var already_added = false; + for (0..match_count) |i| { + if (std.mem.eql(u8, matches[i], flag)) { + already_added = true; + break; + } + } + if (!already_added) { + matches[match_count] = flag; + match_count += 1; + if (match_count >= 5) break; + } + } + } + } + + if (match_count == 0) return null; + return matches[0..match_count]; +} + +const Command = enum { + init, + sync, + queue, + requeue, + status, + monitor, + cancel, + prune, + watch, + dataset, + experiment, + narrative, + outcome, + info, + logs, + annotate, + validate, + compare, + find, + export_cmd, + unknown, +}; + +/// Format suggestions into a helpful message +pub fn formatSuggestionMessage( + allocator: std.mem.Allocator, + input: []const u8, + suggestions: []const []const u8, +) ![]u8 { + if (suggestions.len == 0) return allocator.dupe(u8, ""); + + var buf = std.ArrayList(u8).empty; + defer buf.deinit(allocator); + + const writer = buf.writer(allocator); + try writer.print("Did you mean for '{s}': ", .{input}); + + for (suggestions, 0..) |sugg, i| { + if (i > 0) { + if (i == suggestions.len - 1) { + try writer.writeAll(" or "); + } else { + try writer.writeAll(", "); + } + } + try writer.print("'{s}'", .{sugg}); + } + + try writer.writeAll("?\n"); + + return buf.toOwnedSlice(allocator); +} + +/// Test the suggestion system +pub fn testSuggestions() !void { + const allocator = std.testing.allocator; + + // Test Levenshtein distance + const dist1 = try levenshteinDistance(allocator, "queue", "quee"); + std.debug.assert(dist1 == 1); + + const dist2 = try levenshteinDistance(allocator, "status", "statis"); + std.debug.assert(dist2 == 1); + + // Test suggestions + const candidates = [_][]const u8{ "queue", "query", "quiet", "quit" }; + const suggestions = try findSuggestions(allocator, "quee", &candidates, 2, 3); + defer { + for (suggestions) |s| allocator.free(s); + allocator.free(suggestions); + } + std.debug.assert(suggestions.len > 0); + std.debug.assert(std.mem.eql(u8, suggestions[0], "queue")); + + std.debug.print("Suggestion tests passed!\n", .{}); +} diff --git a/internal/middleware/privacy.go b/internal/middleware/privacy.go new file mode 100644 index 0000000..47f99c9 --- /dev/null +++ b/internal/middleware/privacy.go @@ -0,0 +1,94 @@ +// Package middleware provides privacy enforcement for experiment access control. +package middleware + +import ( + "context" + "fmt" + + "github.com/jfraeys/fetch_ml/internal/auth" +) + +// PrivacyLevel defines experiment visibility levels. +type PrivacyLevel string + +const ( + // PrivacyPrivate restricts access to owner only. + PrivacyPrivate PrivacyLevel = "private" + // PrivacyTeam allows team members to view. + PrivacyTeam PrivacyLevel = "team" + // PrivacyPublic allows all authenticated users. + PrivacyPublic PrivacyLevel = "public" + // PrivacyAnonymized allows access with PII stripped. + PrivacyAnonymized PrivacyLevel = "anonymized" +) + +// PrivacyEnforcer handles privacy access control. +type PrivacyEnforcer struct { + enforceTeams bool + auditAccess bool +} + +// NewPrivacyEnforcer creates a privacy enforcer. +func NewPrivacyEnforcer(enforceTeams, auditAccess bool) *PrivacyEnforcer { + return &PrivacyEnforcer{ + enforceTeams: enforceTeams, + auditAccess: auditAccess, + } +} + +// CanAccess checks if a user can access an experiment. +func (pe *PrivacyEnforcer) CanAccess( + ctx context.Context, + user *auth.User, + experimentOwner string, + level string, + team string, +) (bool, error) { + privacyLevel := GetPrivacyLevelFromString(level) + switch privacyLevel { + case PrivacyPublic: + return true, nil + case PrivacyPrivate: + return user.Name == experimentOwner || user.Admin, nil + case PrivacyTeam: + if user.Name == experimentOwner || user.Admin { + return true, nil + } + if !pe.enforceTeams { + return true, nil // Teams not enforced, allow access + } + // Check if user is in same team + return pe.isUserInTeam(ctx, user, team) + case PrivacyAnonymized: + // Anonymized data is accessible but with PII stripped + return true, nil + default: + return false, fmt.Errorf("unknown privacy level: %s", privacyLevel) + } +} + +func (pe *PrivacyEnforcer) isUserInTeam(ctx context.Context, user *auth.User, team string) (bool, error) { + // TODO: Implement team membership check + // This could query a teams database or use JWT claims + // For now, deny access if teams enforcement is on but no check implemented + _ = ctx + _ = user + _ = team + return false, nil +} + +// GetPrivacyLevelFromString converts string to PrivacyLevel. +func GetPrivacyLevelFromString(level string) PrivacyLevel { + switch level { + case "private": + return PrivacyPrivate + case "team": + return PrivacyTeam + case "public": + return PrivacyPublic + case "anonymized": + return PrivacyAnonymized + default: + return PrivacyPrivate // Default to private for safety + } +} diff --git a/internal/privacy/pii.go b/internal/privacy/pii.go new file mode 100644 index 0000000..6da17d0 --- /dev/null +++ b/internal/privacy/pii.go @@ -0,0 +1,55 @@ +// Package privacy provides PII detection for narratives and annotations. +package privacy + +import ( + "regexp" +) + +// piiPatterns contains regex patterns for detecting PII. +var piiPatterns = map[string]*regexp.Regexp{ + "email": regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`), + "ssn": regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`), + "phone": regexp.MustCompile(`\b\d{3}-\d{3}-\d{4}\b`), + "credit_card": regexp.MustCompile(`\b(?:\d[ -]*?){13,16}\b`), + "ip_address": regexp.MustCompile(`\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b`), +} + +// PIIFinding represents a detected PII instance. +type PIIFinding struct { + Type string `json:"type"` + Position int `json:"position"` + Length int `json:"length"` + Sample string `json:"sample"` // Redacted sample +} + +// DetectPII scans text for potential PII. +func DetectPII(text string) []PIIFinding { + var findings []PIIFinding + + for piiType, pattern := range piiPatterns { + matches := pattern.FindAllStringIndex(text, -1) + for _, match := range matches { + findings = append(findings, PIIFinding{ + Type: piiType, + Position: match[0], + Length: match[1] - match[0], + Sample: RedactSample(text[match[0]:match[1]]), + }) + } + } + + return findings +} + +// HasPII returns true if text contains PII. +func HasPII(text string) bool { + return len(DetectPII(text)) > 0 +} + +// RedactSample creates a safe sample for reporting. +func RedactSample(match string) string { + if len(match) <= 4 { + return "[PII]" + } + return match[:2] + "..." + match[len(match)-2:] +}