package fileutil_test

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/jfraeys/fetch_ml/internal/fileutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNewSecurePathValidator checks that the constructor returns a non-nil
// validator whose BasePath equals the path it was given.
func TestNewSecurePathValidator(t *testing.T) {
	t.Parallel()

	validator := fileutil.NewSecurePathValidator("/tmp/test")
	require.NotNil(t, validator)
	assert.Equal(t, "/tmp/test", validator.BasePath)
}

// TestValidatePath runs ValidatePath against a validator rooted at a fresh
// temporary directory. Paths that stay below the base are expected to succeed
// with a non-empty result; parent traversals and an absolute path outside the
// base are expected to return an error.
func TestValidatePath(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()
	validator := fileutil.NewSecurePathValidator(tmpDir)

	tests := []struct {
		name    string
		path    string
		wantErr bool
	}{
		{"valid relative path", "subdir/file.txt", false},
		{"valid nested path", "a/b/c/file.txt", false},
		{"current directory", ".", false},
		{"parent traversal blocked", "../escape", true},
		{"deep parent traversal", "a/../../escape", true},
		{"absolute outside base", "/etc/passwd", true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result, err := validator.ValidatePath(tc.path)
			if tc.wantErr {
				require.Error(t, err, "Expected error for path: %s", tc.path)
				return
			}
			require.NoError(t, err, "Unexpected error for path: %s", tc.path)
			assert.NotEmpty(t, result)
		})
	}
}

// TestValidatePathEmptyBase checks that a validator created with an empty base
// path returns an error mentioning "base path not set".
func TestValidatePathEmptyBase(t *testing.T) {
	t.Parallel()

	validator := fileutil.NewSecurePathValidator("")
	_, err := validator.ValidatePath("file.txt")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "base path not set")
}

// TestValidateFileTypeMagicBytes writes small files with known leading bytes
// into a temporary directory and checks that ValidateFileType, given
// AllAllowedTypes, reports the expected type name for each, and returns an
// error for a file with a dangerous extension.
func TestValidateFileTypeMagicBytes(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()

	tests := []struct {
		name     string
		content  []byte
		filename string
		wantType string
		wantErr  bool
	}{
		{"SafeTensors (ZIP)", []byte{0x50, 0x4B, 0x03, 0x04}, "model.safetensors", "safetensors", false},
		{"GGUF", []byte{0x47, 0x47, 0x55, 0x46}, "model.gguf", "gguf", false},
		{"HDF5", []byte{0x89, 0x48, 0x44, 0x46}, "model.h5", "hdf5", false},
		{"NumPy", []byte{0x93, 0x4E, 0x55, 0x4D}, "model.npy", "numpy", false},
		{"JSON", []byte(`{"key": "value"}`), "config.json", "json", false},
		{"dangerous extension", []byte{0x00}, "model.pt", "", true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Each case uses a distinct filename, so sharing tmpDir is safe.
			path := filepath.Join(tmpDir, tc.filename)
			err := os.WriteFile(path, tc.content, 0o644)
			require.NoError(t, err)

			fileType, err := fileutil.ValidateFileType(path, fileutil.AllAllowedTypes)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, fileType)
			assert.Equal(t, tc.wantType, fileType.Name)
		})
	}
}

// TestIsAllowedExtension checks IsAllowedExtension against AllAllowedTypes for
// a list of filenames: the model/data/config extensions are expected to be
// allowed, the ones marked "Dangerous" are expected to be rejected.
func TestIsAllowedExtension(t *testing.T) {
	t.Parallel()

	tests := []struct {
		filename string
		expected bool
	}{
		{"model.safetensors", true},
		{"model.gguf", true},
		{"model.h5", true},
		{"model.npy", true},
		{"config.json", true},
		{"data.csv", true},
		{"config.yaml", true},
		{"readme.txt", true},
		{"model.pt", false},    // Dangerous
		{"model.pkl", false},   // Dangerous
		{"script.sh", false},   // Dangerous
		{"archive.zip", false}, // Dangerous
	}

	for _, tc := range tests {
		t.Run(tc.filename, func(t *testing.T) {
			result := fileutil.IsAllowedExtension(tc.filename, fileutil.AllAllowedTypes)
			assert.Equal(t, tc.expected, result)
		})
	}
}

// TestValidateDatasetFile writes sample files into a temporary directory and
// checks that ValidateDatasetFile accepts the safetensors, numpy and JSON
// samples and returns an error for a .pkl file.
func TestValidateDatasetFile(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()

	tests := []struct {
		name     string
		content  []byte
		filename string
		wantErr  bool
	}{
		{"valid safetensors", []byte{0x50, 0x4B, 0x03, 0x04}, "model.safetensors", false},
		{"valid numpy", []byte{0x93, 0x4E, 0x55, 0x4D, 0x00}, "model.npy", false},
		{"valid json", []byte(`{}`), "data.json", false},
		{"dangerous pickle", []byte{0x00}, "model.pkl", true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			path := filepath.Join(tmpDir, tc.filename)
			err := os.WriteFile(path, tc.content, 0o644)
			require.NoError(t, err)

			err = fileutil.ValidateDatasetFile(path)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
		})
	}
}

// TestValidateModelFile writes sample files into a temporary directory and
// checks that ValidateModelFile accepts the four binary model samples and
// returns an error for a JSON file and for a .pkl file.
func TestValidateModelFile(t *testing.T) {
	t.Parallel()

	tmpDir := t.TempDir()

	tests := []struct {
		name     string
		content  []byte
		filename string
		wantErr  bool
	}{
		{"valid safetensors", []byte{0x50, 0x4B, 0x03, 0x04}, "model.safetensors", false},
		{"valid gguf", []byte{0x47, 0x47, 0x55, 0x46}, "model.gguf", false},
		{"valid hdf5", []byte{0x89, 0x48, 0x44, 0x46}, "model.h5", false},
		{"valid numpy", []byte{0x93, 0x4E, 0x55, 0x4D, 0x00}, "model.npy", false},
		{"json not model", []byte(`{}`), "data.json", true}, // JSON not in BinaryModelTypes
		{"dangerous pickle", []byte{0x00}, "model.pkl", true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			path := filepath.Join(tmpDir, tc.filename)
			err := os.WriteFile(path, tc.content, 0o644)
			require.NoError(t, err)

			err = fileutil.ValidateModelFile(path)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
		})
	}
}

// TestValidateFileTypeNotFound checks that ValidateFileType returns an error
// containing "failed to open file" when the path does not exist.
func TestValidateFileTypeNotFound(t *testing.T) {
	t.Parallel()

	_, err := fileutil.ValidateFileType("/nonexistent/path/file.txt", fileutil.AllAllowedTypes)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to open file")
}

// TestBinaryModelTypes pins the contents of BinaryModelTypes: exactly four
// entries (SafeTensors, GGUF, HDF5, NumPy) and none of the text-based types.
func TestBinaryModelTypes(t *testing.T) {
	t.Parallel()

	assert.Len(t, fileutil.BinaryModelTypes, 4)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.SafeTensors)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.GGUF)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.HDF5)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.NumPy)

	// JSON, CSV, YAML, Text should NOT be in BinaryModelTypes
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.JSON)
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.CSV)
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.YAML)
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.Text)
}

// TestAllAllowedTypes pins the contents of AllAllowedTypes: exactly eight
// entries covering the four binary model types plus JSON, CSV, YAML and Text.
func TestAllAllowedTypes(t *testing.T) {
	t.Parallel()

	assert.Len(t, fileutil.AllAllowedTypes, 8)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.SafeTensors)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.GGUF)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.HDF5)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.NumPy)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.JSON)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.CSV)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.YAML)
	assert.Contains(t, fileutil.AllAllowedTypes, fileutil.Text)
}

// TestDangerousExtensions checks that DangerousExtensions includes each of the
// extensions listed below; it does not constrain any additional entries.
func TestDangerousExtensions(t *testing.T) {
	t.Parallel()

	assert.Contains(t, fileutil.DangerousExtensions, ".pt")
	assert.Contains(t, fileutil.DangerousExtensions, ".pkl")
	assert.Contains(t, fileutil.DangerousExtensions, ".pickle")
	assert.Contains(t, fileutil.DangerousExtensions, ".pth")
	assert.Contains(t, fileutil.DangerousExtensions, ".joblib")
	assert.Contains(t, fileutil.DangerousExtensions, ".exe")
	assert.Contains(t, fileutil.DangerousExtensions, ".sh")
	assert.Contains(t, fileutil.DangerousExtensions, ".zip")
	assert.Contains(t, fileutil.DangerousExtensions, ".tar")
}