fetch_ml/internal/fileutil/fileutil_test.go
Jeremie Fraeys f827ee522a
test(tracking/plugins): add PodmanInterface and comprehensive plugin tests for 91% coverage
Refactor plugins to use interface for testability:
- Add PodmanInterface to container package (StartContainer, StopContainer, RemoveContainer)
- Update MLflow plugin to use container.PodmanInterface
- Update TensorBoard plugin to use container.PodmanInterface
- Add comprehensive mocked tests for all three plugins (wandb, mlflow, tensorboard)
- Coverage increased from 18% to 91.4%
2026-03-14 16:59:16 -04:00

256 lines
7.4 KiB
Go

package fileutil_test
import (
"os"
"path/filepath"
"testing"
"github.com/jfraeys/fetch_ml/internal/fileutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewSecurePathValidator checks that the constructor returns a non-nil
// validator whose BasePath is exactly the directory it was given.
func TestNewSecurePathValidator(t *testing.T) {
	t.Parallel()

	const base = "/tmp/test"
	v := fileutil.NewSecurePathValidator(base)
	require.NotNil(t, v)
	assert.Equal(t, base, v.BasePath)
}
// TestValidatePath exercises ValidatePath against a fresh temporary base
// directory. Relative paths that stay inside the base must succeed with a
// non-empty result; paths that climb out via ".." or name an absolute
// location elsewhere must be rejected.
func TestValidatePath(t *testing.T) {
	t.Parallel()

	validator := fileutil.NewSecurePathValidator(t.TempDir())

	cases := []struct {
		name    string
		path    string
		wantErr bool
	}{
		{name: "valid relative path", path: "subdir/file.txt"},
		{name: "valid nested path", path: "a/b/c/file.txt"},
		{name: "current directory", path: "."},
		{name: "parent traversal blocked", path: "../escape", wantErr: true},
		{name: "deep parent traversal", path: "a/../../escape", wantErr: true},
		{name: "absolute outside base", path: "/etc/passwd", wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := validator.ValidatePath(c.path)
			if c.wantErr {
				require.Error(t, err, "Expected error for path: %s", c.path)
				return
			}
			require.NoError(t, err, "Unexpected error for path: %s", c.path)
			assert.NotEmpty(t, got)
		})
	}
}
// TestValidatePathEmptyBase verifies that a validator constructed with an
// empty base directory refuses to validate any path and says why.
func TestValidatePathEmptyBase(t *testing.T) {
	t.Parallel()

	v := fileutil.NewSecurePathValidator("")
	_, validateErr := v.ValidatePath("file.txt")
	require.Error(t, validateErr)
	assert.Contains(t, validateErr.Error(), "base path not set")
}
// TestValidateFileTypeMagicBytes writes small files whose leading bytes match
// each supported format and checks that ValidateFileType reports the expected
// type name. A file with a dangerous extension (.pt) must be rejected.
func TestValidateFileTypeMagicBytes(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()

	cases := []struct {
		name     string
		content  []byte
		filename string
		wantType string
		wantErr  bool
	}{
		{name: "SafeTensors (ZIP)", content: []byte{0x50, 0x4B, 0x03, 0x04}, filename: "model.safetensors", wantType: "safetensors"},
		{name: "GGUF", content: []byte{0x47, 0x47, 0x55, 0x46}, filename: "model.gguf", wantType: "gguf"},
		{name: "HDF5", content: []byte{0x89, 0x48, 0x44, 0x46}, filename: "model.h5", wantType: "hdf5"},
		{name: "NumPy", content: []byte{0x93, 0x4E, 0x55, 0x4D}, filename: "model.npy", wantType: "numpy"},
		{name: "JSON", content: []byte(`{"key": "value"}`), filename: "config.json", wantType: "json"},
		{name: "dangerous extension", content: []byte{0x00}, filename: "model.pt", wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			p := filepath.Join(dir, c.filename)
			require.NoError(t, os.WriteFile(p, c.content, 0o644))

			ft, err := fileutil.ValidateFileType(p, fileutil.AllAllowedTypes)
			if c.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, ft)
			assert.Equal(t, c.wantType, ft.Name)
		})
	}
}
// TestIsAllowedExtension checks IsAllowedExtension against AllAllowedTypes:
// every filename in the allowed list must be accepted, and every filename in
// the dangerous list must be rejected. Each filename runs as its own subtest.
func TestIsAllowedExtension(t *testing.T) {
	t.Parallel()

	allowed := []string{
		"model.safetensors",
		"model.gguf",
		"model.h5",
		"model.npy",
		"config.json",
		"data.csv",
		"config.yaml",
		"readme.txt",
	}
	// Extensions flagged as dangerous; these must be rejected.
	dangerous := []string{
		"model.pt",
		"model.pkl",
		"script.sh",
		"archive.zip",
	}

	check := func(filename string, want bool) {
		t.Run(filename, func(t *testing.T) {
			got := fileutil.IsAllowedExtension(filename, fileutil.AllAllowedTypes)
			assert.Equal(t, want, got)
		})
	}

	for _, name := range allowed {
		check(name, true)
	}
	for _, name := range dangerous {
		check(name, false)
	}
}
// TestValidateDatasetFile writes sample files to a temporary directory and
// checks that ValidateDatasetFile accepts recognised dataset formats while
// rejecting a pickle file.
func TestValidateDatasetFile(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()

	cases := []struct {
		name     string
		content  []byte
		filename string
		wantErr  bool
	}{
		{name: "valid safetensors", content: []byte{0x50, 0x4B, 0x03, 0x04}, filename: "model.safetensors"},
		{name: "valid numpy", content: []byte{0x93, 0x4E, 0x55, 0x4D, 0x00}, filename: "model.npy"},
		{name: "valid json", content: []byte(`{}`), filename: "data.json"},
		{name: "dangerous pickle", content: []byte{0x00}, filename: "model.pkl", wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			p := filepath.Join(dir, c.filename)
			require.NoError(t, os.WriteFile(p, c.content, 0o644))

			// Pick the expectation up front so the validation call site is shared.
			expect := require.NoError
			if c.wantErr {
				expect = require.Error
			}
			expect(t, fileutil.ValidateDatasetFile(p))
		})
	}
}
// TestValidateModelFile writes sample files to a temporary directory and
// checks that ValidateModelFile accepts only binary model formats. Text
// formats such as JSON and dangerous formats such as pickle must be rejected.
func TestValidateModelFile(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()

	cases := []struct {
		name     string
		content  []byte
		filename string
		wantErr  bool
	}{
		{name: "valid safetensors", content: []byte{0x50, 0x4B, 0x03, 0x04}, filename: "model.safetensors"},
		{name: "valid gguf", content: []byte{0x47, 0x47, 0x55, 0x46}, filename: "model.gguf"},
		{name: "valid hdf5", content: []byte{0x89, 0x48, 0x44, 0x46}, filename: "model.h5"},
		{name: "valid numpy", content: []byte{0x93, 0x4E, 0x55, 0x4D, 0x00}, filename: "model.npy"},
		// JSON is not in BinaryModelTypes, so it is not a valid model file.
		{name: "json not model", content: []byte(`{}`), filename: "data.json", wantErr: true},
		{name: "dangerous pickle", content: []byte{0x00}, filename: "model.pkl", wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			p := filepath.Join(dir, c.filename)
			require.NoError(t, os.WriteFile(p, c.content, 0o644))

			validateErr := fileutil.ValidateModelFile(p)
			if c.wantErr {
				require.Error(t, validateErr)
				return
			}
			require.NoError(t, validateErr)
		})
	}
}
// TestValidateFileTypeNotFound confirms that ValidateFileType returns an
// open-failure error when asked about a path that does not exist.
func TestValidateFileTypeNotFound(t *testing.T) {
	t.Parallel()

	const missing = "/nonexistent/path/file.txt"
	_, openErr := fileutil.ValidateFileType(missing, fileutil.AllAllowedTypes)
	require.Error(t, openErr)
	assert.Contains(t, openErr.Error(), "failed to open file")
}
// TestBinaryModelTypes pins the contents of BinaryModelTypes. It must hold
// exactly the four binary model formats (SafeTensors, GGUF, HDF5, NumPy) and
// none of the text formats (JSON, CSV, YAML, Text).
func TestBinaryModelTypes(t *testing.T) {
	t.Parallel()

	assert.Len(t, fileutil.BinaryModelTypes, 4)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.SafeTensors)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.GGUF)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.HDF5)
	assert.Contains(t, fileutil.BinaryModelTypes, fileutil.NumPy)

	// JSON, CSV, YAML and Text must NOT be in BinaryModelTypes. All four are
	// asserted explicitly so a regression in any one of them is reported by name.
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.JSON)
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.CSV)
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.YAML)
	assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.Text)
}
// TestAllAllowedTypes pins AllAllowedTypes to exactly eight entries: the four
// binary model formats plus the JSON, CSV, YAML and Text formats.
func TestAllAllowedTypes(t *testing.T) {
	t.Parallel()

	assert.Len(t, fileutil.AllAllowedTypes, 8)

	want := []any{
		fileutil.SafeTensors,
		fileutil.GGUF,
		fileutil.HDF5,
		fileutil.NumPy,
		fileutil.JSON,
		fileutil.CSV,
		fileutil.YAML,
		fileutil.Text,
	}
	for _, ft := range want {
		assert.Contains(t, fileutil.AllAllowedTypes, ft)
	}
}
// TestDangerousExtensions checks that DangerousExtensions contains every
// extension in the expected list (dot-prefixed, as stored in the package).
func TestDangerousExtensions(t *testing.T) {
	t.Parallel()

	expected := []string{
		".pt",
		".pkl",
		".pickle",
		".pth",
		".joblib",
		".exe",
		".sh",
		".zip",
		".tar",
	}
	for _, ext := range expected {
		assert.Contains(t, fileutil.DangerousExtensions, ext)
	}
}