Refactor plugins to use interface for testability: - Add PodmanInterface to container package (StartContainer, StopContainer, RemoveContainer) - Update MLflow plugin to use container.PodmanInterface - Update TensorBoard plugin to use container.PodmanInterface - Add comprehensive mocked tests for all three plugins (wandb, mlflow, tensorboard) - Coverage increased from 18% to 91.4%
256 lines
7.4 KiB
Go
256 lines
7.4 KiB
Go
package fileutil_test
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/jfraeys/fetch_ml/internal/fileutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestNewSecurePathValidator tests validator creation
|
|
func TestNewSecurePathValidator(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
validator := fileutil.NewSecurePathValidator("/tmp/test")
|
|
require.NotNil(t, validator)
|
|
assert.Equal(t, "/tmp/test", validator.BasePath)
|
|
}
|
|
|
|
// TestValidatePath tests path validation
|
|
func TestValidatePath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
validator := fileutil.NewSecurePathValidator(tmpDir)
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
wantErr bool
|
|
}{
|
|
{"valid relative path", "subdir/file.txt", false},
|
|
{"valid nested path", "a/b/c/file.txt", false},
|
|
{"current directory", ".", false},
|
|
{"parent traversal blocked", "../escape", true},
|
|
{"deep parent traversal", "a/../../escape", true},
|
|
{"absolute outside base", "/etc/passwd", true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
result, err := validator.ValidatePath(tc.path)
|
|
if tc.wantErr {
|
|
require.Error(t, err, "Expected error for path: %s", tc.path)
|
|
} else {
|
|
require.NoError(t, err, "Unexpected error for path: %s", tc.path)
|
|
assert.NotEmpty(t, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidatePathEmptyBase tests validation with empty base
|
|
func TestValidatePathEmptyBase(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
validator := fileutil.NewSecurePathValidator("")
|
|
_, err := validator.ValidatePath("file.txt")
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "base path not set")
|
|
}
|
|
|
|
// TestValidateFileTypeMagicBytes tests file type detection by magic bytes
|
|
func TestValidateFileTypeMagicBytes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
|
|
tests := []struct {
|
|
name string
|
|
content []byte
|
|
filename string
|
|
wantType string
|
|
wantErr bool
|
|
}{
|
|
{"SafeTensors (ZIP)", []byte{0x50, 0x4B, 0x03, 0x04}, "model.safetensors", "safetensors", false},
|
|
{"GGUF", []byte{0x47, 0x47, 0x55, 0x46}, "model.gguf", "gguf", false},
|
|
{"HDF5", []byte{0x89, 0x48, 0x44, 0x46}, "model.h5", "hdf5", false},
|
|
{"NumPy", []byte{0x93, 0x4E, 0x55, 0x4D}, "model.npy", "numpy", false},
|
|
{"JSON", []byte(`{"key": "value"}`), "config.json", "json", false},
|
|
{"dangerous extension", []byte{0x00}, "model.pt", "", true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
path := filepath.Join(tmpDir, tc.filename)
|
|
err := os.WriteFile(path, tc.content, 0644)
|
|
require.NoError(t, err)
|
|
|
|
fileType, err := fileutil.ValidateFileType(path, fileutil.AllAllowedTypes)
|
|
if tc.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.NotNil(t, fileType)
|
|
assert.Equal(t, tc.wantType, fileType.Name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestIsAllowedExtension tests extension checking
|
|
func TestIsAllowedExtension(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
filename string
|
|
expected bool
|
|
}{
|
|
{"model.safetensors", true},
|
|
{"model.gguf", true},
|
|
{"model.h5", true},
|
|
{"model.npy", true},
|
|
{"config.json", true},
|
|
{"data.csv", true},
|
|
{"config.yaml", true},
|
|
{"readme.txt", true},
|
|
{"model.pt", false}, // Dangerous
|
|
{"model.pkl", false}, // Dangerous
|
|
{"script.sh", false}, // Dangerous
|
|
{"archive.zip", false}, // Dangerous
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.filename, func(t *testing.T) {
|
|
result := fileutil.IsAllowedExtension(tc.filename, fileutil.AllAllowedTypes)
|
|
assert.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidateDatasetFile tests dataset file validation
|
|
func TestValidateDatasetFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
|
|
tests := []struct {
|
|
name string
|
|
content []byte
|
|
filename string
|
|
wantErr bool
|
|
}{
|
|
{"valid safetensors", []byte{0x50, 0x4B, 0x03, 0x04}, "model.safetensors", false},
|
|
{"valid numpy", []byte{0x93, 0x4E, 0x55, 0x4D, 0x00}, "model.npy", false},
|
|
{"valid json", []byte(`{}`), "data.json", false},
|
|
{"dangerous pickle", []byte{0x00}, "model.pkl", true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
path := filepath.Join(tmpDir, tc.filename)
|
|
err := os.WriteFile(path, tc.content, 0644)
|
|
require.NoError(t, err)
|
|
|
|
err = fileutil.ValidateDatasetFile(path)
|
|
if tc.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidateModelFile tests model file validation
|
|
func TestValidateModelFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tmpDir := t.TempDir()
|
|
|
|
tests := []struct {
|
|
name string
|
|
content []byte
|
|
filename string
|
|
wantErr bool
|
|
}{
|
|
{"valid safetensors", []byte{0x50, 0x4B, 0x03, 0x04}, "model.safetensors", false},
|
|
{"valid gguf", []byte{0x47, 0x47, 0x55, 0x46}, "model.gguf", false},
|
|
{"valid hdf5", []byte{0x89, 0x48, 0x44, 0x46}, "model.h5", false},
|
|
{"valid numpy", []byte{0x93, 0x4E, 0x55, 0x4D, 0x00}, "model.npy", false},
|
|
{"json not model", []byte(`{}`), "data.json", true}, // JSON not in BinaryModelTypes
|
|
{"dangerous pickle", []byte{0x00}, "model.pkl", true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
path := filepath.Join(tmpDir, tc.filename)
|
|
err := os.WriteFile(path, tc.content, 0644)
|
|
require.NoError(t, err)
|
|
|
|
err = fileutil.ValidateModelFile(path)
|
|
if tc.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidateFileTypeNotFound tests nonexistent file
|
|
func TestValidateFileTypeNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := fileutil.ValidateFileType("/nonexistent/path/file.txt", fileutil.AllAllowedTypes)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to open file")
|
|
}
|
|
|
|
// TestBinaryModelTypes tests the binary model types slice
|
|
func TestBinaryModelTypes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Len(t, fileutil.BinaryModelTypes, 4)
|
|
assert.Contains(t, fileutil.BinaryModelTypes, fileutil.SafeTensors)
|
|
assert.Contains(t, fileutil.BinaryModelTypes, fileutil.GGUF)
|
|
assert.Contains(t, fileutil.BinaryModelTypes, fileutil.HDF5)
|
|
assert.Contains(t, fileutil.BinaryModelTypes, fileutil.NumPy)
|
|
|
|
// JSON, CSV, YAML, Text should NOT be in BinaryModelTypes
|
|
assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.JSON)
|
|
assert.NotContains(t, fileutil.BinaryModelTypes, fileutil.CSV)
|
|
}
|
|
|
|
// TestAllAllowedTypes tests the allowed types slice
|
|
func TestAllAllowedTypes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Len(t, fileutil.AllAllowedTypes, 8)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.SafeTensors)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.GGUF)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.HDF5)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.NumPy)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.JSON)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.CSV)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.YAML)
|
|
assert.Contains(t, fileutil.AllAllowedTypes, fileutil.Text)
|
|
}
|
|
|
|
// TestDangerousExtensions tests dangerous extension list
|
|
func TestDangerousExtensions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".pt")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".pkl")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".pickle")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".pth")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".joblib")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".exe")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".sh")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".zip")
|
|
assert.Contains(t, fileutil.DangerousExtensions, ".tar")
|
|
}
|