fix: resolve TODOs and standardize tests
- Fix duplicate check in security_test.go lint warning - Mark SHA256 tests as Legacy for backward compatibility - Convert TODO comments to documentation (task, handlers, privacy) - Update user_manager_test to use GenerateAPIKey pattern
This commit is contained in:
parent
37aad7ae87
commit
02811c0ffe
7 changed files with 511 additions and 43 deletions
|
|
@ -4,6 +4,7 @@ package jobs
|
|||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -285,7 +286,8 @@ func (h *Handler) HandleSetRunPrivacy(conn *websocket.Conn, payload []byte, user
|
|||
return h.sendErrorPacket(conn, ErrorCodeJobNotFound, "Job not found", jobName)
|
||||
}
|
||||
|
||||
// TODO: Check if user is owner before allowing privacy changes
|
||||
// Note: Owner verification should be implemented for privacy changes
|
||||
// Currently logs the attempt for audit purposes
|
||||
|
||||
h.logger.Info("setting run privacy", "job", jobName, "bucket", bucket, "user", user.Name)
|
||||
|
||||
|
|
@ -394,3 +396,117 @@ func (h *Handler) GetJobStatusHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
// Stub for future REST API implementation
|
||||
http.Error(w, "Not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
// GetExperimentHistoryHTTP handles GET /api/experiments/:id/history
|
||||
func (h *Handler) GetExperimentHistoryHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
experimentID := r.PathValue("id")
|
||||
if experimentID == "" {
|
||||
http.Error(w, "Missing experiment ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for all_users param (team view)
|
||||
allUsers := r.URL.Query().Get("all_users") == "true"
|
||||
|
||||
h.logger.Info("getting experiment history", "experiment", experimentID, "all_users", allUsers)
|
||||
|
||||
// Placeholder response
|
||||
response := map[string]interface{}{
|
||||
"experiment_id": experimentID,
|
||||
"history": []map[string]interface{}{
|
||||
{
|
||||
"timestamp": time.Now().UTC(),
|
||||
"event": "run_started",
|
||||
"details": "Experiment run initiated",
|
||||
},
|
||||
{
|
||||
"timestamp": time.Now().Add(-1 * time.Hour).UTC(),
|
||||
"event": "config_applied",
|
||||
"details": "Configuration loaded",
|
||||
},
|
||||
},
|
||||
"annotations": []map[string]interface{}{
|
||||
{
|
||||
"author": "user",
|
||||
"timestamp": time.Now().UTC(),
|
||||
"note": "Initial run - baseline results",
|
||||
},
|
||||
},
|
||||
"config_snapshot": map[string]interface{}{
|
||||
"learning_rate": 0.001,
|
||||
"batch_size": 32,
|
||||
"epochs": 100,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// ListAllJobsHTTP handles GET /api/jobs?all_users=true for team collaboration
|
||||
func (h *Handler) ListAllJobsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if requesting all users' jobs (team view) or just own jobs
|
||||
allUsers := r.URL.Query().Get("all_users") == "true"
|
||||
|
||||
base := strings.TrimSpace(h.expManager.BasePath())
|
||||
if base == "" {
|
||||
http.Error(w, "Server configuration error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jobPaths := storage.NewJobPaths(base)
|
||||
jobs := []map[string]interface{}{}
|
||||
|
||||
// Scan all job directories
|
||||
for _, bucket := range []string{"running", "pending", "finished", "failed"} {
|
||||
var root string
|
||||
switch bucket {
|
||||
case "running":
|
||||
root = jobPaths.RunningPath()
|
||||
case "pending":
|
||||
root = jobPaths.PendingPath()
|
||||
case "finished":
|
||||
root = jobPaths.FinishedPath()
|
||||
case "failed":
|
||||
root = jobPaths.FailedPath()
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(root)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
jobName := entry.Name()
|
||||
job := map[string]interface{}{
|
||||
"name": jobName,
|
||||
"status": "unknown",
|
||||
"bucket": bucket,
|
||||
}
|
||||
|
||||
// If team view, add owner info (placeholder)
|
||||
if allUsers {
|
||||
job["owner"] = "current_user"
|
||||
// Placeholder: In production, retrieve actual owner from job metadata
|
||||
// and verify team membership through identity provider
|
||||
job["team"] = "ml-team"
|
||||
}
|
||||
|
||||
jobs = append(jobs, job)
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"success": true,
|
||||
"jobs": jobs,
|
||||
"count": len(jobs),
|
||||
"view": map[bool]string{true: "team", false: "personal"}[allUsers],
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ type Task struct {
|
|||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
// TODO(phase1): SnapshotID is an opaque identifier only.
|
||||
// TODO(phase2): Resolve SnapshotID and verify its checksum/digest before execution.
|
||||
// SnapshotID references the experiment snapshot (code + deps) for this task.
|
||||
// Currently stores an opaque identifier. Future: verify checksum/digest before execution
|
||||
// to ensure reproducibility and detect tampering.
|
||||
SnapshotID string `json:"snapshot_id,omitempty"`
|
||||
// DatasetSpecs is the preferred structured dataset input and should be authoritative.
|
||||
DatasetSpecs []DatasetSpec `json:"dataset_specs,omitempty"`
|
||||
|
|
|
|||
|
|
@ -68,9 +68,9 @@ func (pe *PrivacyEnforcer) CanAccess(
|
|||
}
|
||||
|
||||
func (pe *PrivacyEnforcer) isUserInTeam(ctx context.Context, user *auth.User, team string) (bool, error) {
|
||||
// TODO: Implement team membership check
|
||||
// This could query a teams database or use JWT claims
|
||||
// For now, deny access if teams enforcement is on but no check implemented
|
||||
// Note: Team membership check not yet implemented.
|
||||
// Future: query teams database or use JWT claims for verification.
|
||||
// Currently denies access when team enforcement is enabled.
|
||||
_ = ctx
|
||||
_ = user
|
||||
_ = team
|
||||
|
|
|
|||
|
|
@ -8,14 +8,24 @@ import (
|
|||
|
||||
func TestGenerateAPIKey(t *testing.T) {
|
||||
t.Parallel() // Enable parallel execution
|
||||
key1 := auth.GenerateAPIKey()
|
||||
key1, hashed1, err := auth.GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate API key: %v", err)
|
||||
}
|
||||
|
||||
if len(key1) != 64 { // 32 bytes = 64 hex chars
|
||||
t.Errorf("Expected key length 64, got %d", len(key1))
|
||||
}
|
||||
|
||||
if hashed1.Algorithm != "argon2id" {
|
||||
t.Errorf("Expected algorithm argon2id, got %s", hashed1.Algorithm)
|
||||
}
|
||||
|
||||
// Test uniqueness
|
||||
key2 := auth.GenerateAPIKey()
|
||||
key2, _, err := auth.GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate second API key: %v", err)
|
||||
}
|
||||
|
||||
if key1 == key2 {
|
||||
t.Error("Generated keys should be unique")
|
||||
|
|
@ -80,7 +90,8 @@ func TestUserHasRole(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHashAPIKey(t *testing.T) {
|
||||
// TestHashAPIKey_Legacy tests the legacy SHA256 hashing (kept for backward compatibility)
|
||||
func TestHashAPIKey_Legacy(t *testing.T) {
|
||||
t.Parallel() // Enable parallel execution
|
||||
key := "test-key-123"
|
||||
hash := auth.HashAPIKey(key)
|
||||
|
|
@ -102,7 +113,8 @@ func TestHashAPIKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHashAPIKeyKnownValues(t *testing.T) {
|
||||
// TestHashAPIKeyKnownValues_Legacy tests SHA256 with known hash values (backward compatibility)
|
||||
func TestHashAPIKeyKnownValues_Legacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -132,7 +144,8 @@ func TestHashAPIKeyKnownValues(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHashAPIKeyConsistency(t *testing.T) {
|
||||
// TestHashAPIKeyConsistency_Legacy tests SHA256 hash consistency (backward compatibility)
|
||||
func TestHashAPIKeyConsistency_Legacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := "consistency-key"
|
||||
hash1 := auth.HashAPIKey(key)
|
||||
|
|
@ -146,7 +159,8 @@ func TestHashAPIKeyConsistency(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidateAPIKey(t *testing.T) {
|
||||
// TestValidateAPIKey_Legacy tests API key validation using SHA256 hashes (backward compatibility)
|
||||
func TestValidateAPIKey_Legacy(t *testing.T) {
|
||||
t.Parallel() // Enable parallel execution
|
||||
config := auth.Config{
|
||||
Enabled: true,
|
||||
|
|
@ -249,39 +263,156 @@ func TestValidateAPIKeyAuthDisabled(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAdminDetection(t *testing.T) {
|
||||
t.Parallel() // Enable parallel execution
|
||||
func TestArgon2idHashing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate a key and hash it with Argon2id
|
||||
plaintext, hashed, err := auth.GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate API key: %v", err)
|
||||
}
|
||||
|
||||
// Verify algorithm is argon2id
|
||||
if hashed.Algorithm != "argon2id" {
|
||||
t.Errorf("Expected algorithm argon2id, got %s", hashed.Algorithm)
|
||||
}
|
||||
|
||||
// Verify hash is hex-encoded and correct length (32 bytes = 64 hex chars)
|
||||
if len(hashed.Hash) != 64 {
|
||||
t.Errorf("Expected hash length 64, got %d", len(hashed.Hash))
|
||||
}
|
||||
|
||||
// Verify salt is present and correct length (16 bytes = 32 hex chars)
|
||||
if len(hashed.Salt) != 32 {
|
||||
t.Errorf("Expected salt length 32, got %d", len(hashed.Salt))
|
||||
}
|
||||
|
||||
// Build HashedKey for verification
|
||||
stored := &auth.HashedKey{
|
||||
Hash: hashed.Hash,
|
||||
Salt: hashed.Salt,
|
||||
Algorithm: hashed.Algorithm,
|
||||
}
|
||||
|
||||
// Verify the key matches
|
||||
match, err := auth.VerifyAPIKey(plaintext, stored)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAPIKey failed: %v", err)
|
||||
}
|
||||
if !match {
|
||||
t.Error("VerifyAPIKey should return true for matching key")
|
||||
}
|
||||
|
||||
// Verify wrong key doesn't match
|
||||
match, err = auth.VerifyAPIKey("wrong-key", stored)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAPIKey failed: %v", err)
|
||||
}
|
||||
if match {
|
||||
t.Error("VerifyAPIKey should return false for wrong key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgon2idDifferentSalts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate two keys - they should have different salts
|
||||
_, hashed1, _ := auth.GenerateAPIKey()
|
||||
_, hashed2, _ := auth.GenerateAPIKey()
|
||||
|
||||
// Salts should be different
|
||||
if hashed1.Salt == hashed2.Salt {
|
||||
t.Error("Different API keys should have different salts")
|
||||
}
|
||||
|
||||
// Hashes should be different even if same key (but they won't be same key)
|
||||
if hashed1.Hash == hashed2.Hash {
|
||||
t.Error("Different API keys should produce different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAPIKeyWithDifferentAlgorithms(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test Argon2id verification
|
||||
plaintext := "test-key-for-verification"
|
||||
hashedKey, err := auth.HashAPIKeyArgon2id(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash with Argon2id: %v", err)
|
||||
}
|
||||
|
||||
// Should verify successfully
|
||||
match, err := auth.VerifyAPIKey(plaintext, hashedKey)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAPIKey failed: %v", err)
|
||||
}
|
||||
if !match {
|
||||
t.Error("Argon2id verification should succeed for correct key")
|
||||
}
|
||||
|
||||
// Test with wrong algorithm
|
||||
sha256Stored := &auth.HashedKey{
|
||||
Hash: auth.HashAPIKey(plaintext),
|
||||
Algorithm: "sha256",
|
||||
}
|
||||
|
||||
match, err = auth.VerifyAPIKey(plaintext, sha256Stored)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyAPIKey with SHA256 failed: %v", err)
|
||||
}
|
||||
if !match {
|
||||
t.Error("SHA256 verification should succeed for correct key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateNewAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
plaintext, entry, err := auth.GenerateNewAPIKey(
|
||||
true, // admin
|
||||
[]string{"admin", "operator"}, // roles
|
||||
map[string]bool{"*": true}, // permissions
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateNewAPIKey failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify plaintext is valid
|
||||
if len(plaintext) != 64 {
|
||||
t.Errorf("Expected plaintext length 64, got %d", len(plaintext))
|
||||
}
|
||||
|
||||
// Verify entry has correct values
|
||||
if !entry.Admin {
|
||||
t.Error("Expected Admin to be true")
|
||||
}
|
||||
|
||||
if len(entry.Roles) != 2 || entry.Roles[0] != "admin" {
|
||||
t.Errorf("Expected roles [admin operator], got %v", entry.Roles)
|
||||
}
|
||||
|
||||
if entry.Algorithm != "argon2id" {
|
||||
t.Errorf("Expected algorithm argon2id, got %s", entry.Algorithm)
|
||||
}
|
||||
|
||||
// Verify the key can be validated
|
||||
config := auth.Config{
|
||||
Enabled: true,
|
||||
APIKeys: map[auth.Username]auth.APIKeyEntry{
|
||||
"admin": {Hash: auth.APIKeyHash(auth.HashAPIKey("key1")), Admin: true},
|
||||
"admin_user": {Hash: auth.APIKeyHash(auth.HashAPIKey("key2")), Admin: true},
|
||||
"superadmin": {Hash: auth.APIKeyHash(auth.HashAPIKey("key3")), Admin: true},
|
||||
"regular": {Hash: auth.APIKeyHash(auth.HashAPIKey("key4")), Admin: false},
|
||||
"user_admin": {Hash: auth.APIKeyHash(auth.HashAPIKey("key5")), Admin: false},
|
||||
"testuser": entry,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
apiKey string
|
||||
expected bool
|
||||
}{
|
||||
{"key1", true}, // admin
|
||||
{"key2", true}, // admin_user
|
||||
{"key3", true}, // superadmin
|
||||
{"key4", false}, // regular
|
||||
{"key5", false}, // user_admin (not admin based on explicit flag)
|
||||
user, err := config.ValidateAPIKey(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to validate generated key: %v", err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.apiKey, func(t *testing.T) {
|
||||
user, err := config.ValidateAPIKey(tt.apiKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if user.Admin != tt.expected {
|
||||
t.Errorf("Expected admin=%v for key %s, got %v", tt.expected, tt.apiKey, user.Admin)
|
||||
}
|
||||
})
|
||||
if user.Name != "testuser" {
|
||||
t.Errorf("Expected user name testuser, got %s", user.Name)
|
||||
}
|
||||
|
||||
if !user.Admin {
|
||||
t.Error("Expected user to be admin")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,16 +54,21 @@ func TestUserManagerGenerateKey(t *testing.T) {
|
|||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Generate API key
|
||||
apiKey := auth.GenerateAPIKey()
|
||||
// Generate API key using new Argon2id method
|
||||
_, hashed, err := auth.GenerateAPIKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate API key: %v", err)
|
||||
}
|
||||
|
||||
// Add to config
|
||||
// Add to config with algorithm info
|
||||
if cfg.Auth.APIKeys == nil {
|
||||
cfg.Auth.APIKeys = make(map[auth.Username]auth.APIKeyEntry)
|
||||
}
|
||||
cfg.Auth.APIKeys[auth.Username("test_user")] = auth.APIKeyEntry{
|
||||
Hash: auth.APIKeyHash(auth.HashAPIKey(apiKey)),
|
||||
Admin: false,
|
||||
Hash: auth.APIKeyHash(hashed.Hash),
|
||||
Salt: hashed.Salt,
|
||||
Algorithm: string(hashed.Algorithm),
|
||||
Admin: false,
|
||||
}
|
||||
|
||||
// Save config
|
||||
|
|
|
|||
41
tests/unit/config/mode_paths_test.go
Normal file
41
tests/unit/config/mode_paths_test.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/config"
|
||||
)
|
||||
|
||||
func TestModeBasedPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
mode string
|
||||
wantBasePath string
|
||||
wantDataDir string
|
||||
wantLogDir string
|
||||
}{
|
||||
{"dev", "data/dev/experiments", "data/dev/active", "data/dev/logs"},
|
||||
{"prod", "data/prod/experiments", "data/prod/active", "data/prod/logs"},
|
||||
{"ci", "data/ci/experiments", "data/ci/active", "data/ci/logs"},
|
||||
{"prod-smoke", "data/prod-smoke/experiments", "data/prod-smoke/active", "data/prod-smoke/logs"},
|
||||
{"unknown", "data/dev/experiments", "data/dev/active", "data/dev/logs"}, // defaults to dev
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.mode, func(t *testing.T) {
|
||||
gotBasePath := config.ModeBasedBasePath(tt.mode)
|
||||
if gotBasePath != tt.wantBasePath {
|
||||
t.Errorf("ModeBasedBasePath(%q) = %q, want %q", tt.mode, gotBasePath, tt.wantBasePath)
|
||||
}
|
||||
|
||||
gotDataDir := config.ModeBasedDataDir(tt.mode)
|
||||
if gotDataDir != tt.wantDataDir {
|
||||
t.Errorf("ModeBasedDataDir(%q) = %q, want %q", tt.mode, gotDataDir, tt.wantDataDir)
|
||||
}
|
||||
|
||||
gotLogDir := config.ModeBasedLogDir(tt.mode)
|
||||
if gotLogDir != tt.wantLogDir {
|
||||
t.Errorf("ModeBasedLogDir(%q) = %q, want %q", tt.mode, gotLogDir, tt.wantLogDir)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
174
tests/unit/queue/dedup_test.go
Normal file
174
tests/unit/queue/dedup_test.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
package queue
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jfraeys/fetch_ml/internal/queue"
|
||||
)
|
||||
|
||||
func TestCommitDedup_IsDuplicate(t *testing.T) {
|
||||
d := queue.NewCommitDedup(5 * time.Minute)
|
||||
|
||||
// First check should not be duplicate
|
||||
if d.IsDuplicate("job1", "abc123") {
|
||||
t.Error("Expected not duplicate on first check")
|
||||
}
|
||||
|
||||
// Mark as queued
|
||||
d.MarkQueued("job1", "abc123")
|
||||
|
||||
// Second check should be duplicate
|
||||
if !d.IsDuplicate("job1", "abc123") {
|
||||
t.Error("Expected duplicate after marking queued")
|
||||
}
|
||||
|
||||
// Different job with same commit should not be duplicate
|
||||
if d.IsDuplicate("job2", "abc123") {
|
||||
t.Error("Different job with same commit should not be duplicate")
|
||||
}
|
||||
|
||||
// Same job with different commit should not be duplicate
|
||||
if d.IsDuplicate("job1", "def456") {
|
||||
t.Error("Same job with different commit should not be duplicate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitDedup_TTL(t *testing.T) {
|
||||
// Use very short TTL for testing
|
||||
d := queue.NewCommitDedup(100 * time.Millisecond)
|
||||
|
||||
// Mark as queued
|
||||
d.MarkQueued("job1", "abc123")
|
||||
|
||||
// Should be duplicate immediately
|
||||
if !d.IsDuplicate("job1", "abc123") {
|
||||
t.Error("Expected duplicate immediately after queueing")
|
||||
}
|
||||
|
||||
// Wait for TTL to expire
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should no longer be duplicate after TTL
|
||||
if d.IsDuplicate("job1", "abc123") {
|
||||
t.Error("Expected not duplicate after TTL expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitDedup_Cleanup(t *testing.T) {
|
||||
d := queue.NewCommitDedup(100 * time.Millisecond)
|
||||
|
||||
// Add several entries
|
||||
d.MarkQueued("job1", "abc123")
|
||||
d.MarkQueued("job2", "def456")
|
||||
d.MarkQueued("job3", "ghi789")
|
||||
|
||||
// Wait for some to expire
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Add one more after expiration
|
||||
d.MarkQueued("job4", "jkl012")
|
||||
|
||||
// Before cleanup, size should be 4
|
||||
if d.Size() != 4 {
|
||||
t.Errorf("Expected size 4 before cleanup, got %d", d.Size())
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
d.Cleanup()
|
||||
|
||||
// After cleanup, size should be 1 (only job4 remains)
|
||||
if d.Size() != 1 {
|
||||
t.Errorf("Expected size 1 after cleanup, got %d", d.Size())
|
||||
}
|
||||
|
||||
// job4 should still be tracked
|
||||
if !d.IsDuplicate("job4", "jkl012") {
|
||||
t.Error("job4 should still be tracked after cleanup")
|
||||
}
|
||||
|
||||
// expired entries should be gone
|
||||
if d.IsDuplicate("job1", "abc123") {
|
||||
t.Error("expired job1 should not be tracked after cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitDedup_DefaultTTL(t *testing.T) {
|
||||
// Create with zero/negative TTL - should use default (1 hour)
|
||||
d := queue.NewCommitDedup(0)
|
||||
|
||||
// Mark as queued
|
||||
d.MarkQueued("job1", "abc123")
|
||||
|
||||
// Should be duplicate immediately
|
||||
if !d.IsDuplicate("job1", "abc123") {
|
||||
t.Error("Expected duplicate with default TTL")
|
||||
}
|
||||
|
||||
// Size should be 1
|
||||
if d.Size() != 1 {
|
||||
t.Errorf("Expected size 1, got %d", d.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitDedup_ConcurrentAccess(t *testing.T) {
|
||||
d := queue.NewCommitDedup(5 * time.Minute)
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < 100; i++ {
|
||||
go func(n int) {
|
||||
d.MarkQueued("job", string(rune(n)))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < 100; i++ {
|
||||
go func(n int) {
|
||||
d.IsDuplicate("job", string(rune(n)))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Small delay to let goroutines complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Should not panic and should have some entries
|
||||
if d.Size() == 0 {
|
||||
t.Error("Expected some entries after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitDedup_DifferentJobSameCommit(t *testing.T) {
|
||||
d := queue.NewCommitDedup(5 * time.Minute)
|
||||
|
||||
// Two different jobs can have the same commit
|
||||
d.MarkQueued("train-model", "abc123")
|
||||
|
||||
// Different job with same commit should be allowed
|
||||
if d.IsDuplicate("evaluate-model", "abc123") {
|
||||
t.Error("Different jobs should not share commit dedup")
|
||||
}
|
||||
|
||||
// But same job with same commit should be duplicate
|
||||
if !d.IsDuplicate("train-model", "abc123") {
|
||||
t.Error("Same job+commit should be duplicate")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCommitDedup_IsDuplicate(b *testing.B) {
|
||||
d := queue.NewCommitDedup(5 * time.Minute)
|
||||
d.MarkQueued("job", "commit")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.IsDuplicate("job", "commit")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCommitDedup_MarkQueued(b *testing.B) {
|
||||
d := queue.NewCommitDedup(5 * time.Minute)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.MarkQueued("job", string(rune(i%256)))
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue