fetch_ml/native/streaming_io/streaming_io.cpp
Jeremie Fraeys d408a60eb1
Some checks failed
Documentation / build-and-publish (push) Waiting to run
Test / test (push) Waiting to run
Checkout test / test (push) Successful in 5s
CI with Native Libraries / test-native (push) Has been cancelled
CI with Native Libraries / build-release (push) Has been cancelled
ci: push all workflow updates
2026-02-12 13:28:15 -05:00

281 lines
7.2 KiB
C++

#include "streaming_io.h"

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <zlib.h>

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <mutex>
#include <string>
#include <system_error>
#include <thread>
#include <vector>
// Simplified tar parsing (USTAR format)
// For production, use libarchive or similar
namespace fs = std::filesystem;
// Layout of one tar header block as it appears in the archive (ustar field
// order). Every field is a fixed-width byte array; numeric fields are ASCII
// octal and text fields are NUL-terminated only when shorter than the field.
//
// A header occupies a full 512-byte block in the archive. The named fields
// only add up to 500 bytes, so the trailing padding is required: code that
// reads or advances by sizeof(TarHeader) would otherwise drift 12 bytes per
// entry.
struct TarHeader {
    char name[100];      // offset   0: member path (or its tail when prefix is used)
    char mode[8];        // offset 100: permission bits
    char uid[8];         // offset 108
    char gid[8];         // offset 116
    char size[12];       // offset 124: payload size in bytes
    char mtime[12];      // offset 136
    char checksum[8];    // offset 148
    char typeflag;       // offset 156: entry type ('0' or '\0' = regular file)
    char linkname[100];  // offset 157
    char magic[6];       // offset 257
    char version[2];     // offset 263
    char uname[32];      // offset 265
    char gname[32];      // offset 297
    char devmajor[8];    // offset 329
    char devminor[8];    // offset 337
    char prefix[155];    // offset 345: leading path component(s), joined with '/'
    char padding[12];    // offset 500: unused tail of the 512-byte block
};
static_assert(sizeof(TarHeader) == 512,
              "a tar header must occupy exactly one 512-byte block");
// Decode a numeric tar header field.
//
// str: start of the field (not necessarily NUL-terminated).
// len: width of the field in bytes.
//
// Two encodings are understood:
//   * ASCII octal: optional leading spaces, then octal digits. Parsing stops
//     at the first byte that is not an octal digit (the space/NUL terminator),
//     so trailing garbage after the terminator is not folded into the value.
//   * Base-256 (GNU extension for values too large for octal, e.g. files of
//     8 GiB and above): the first byte has its high bit set and the remaining
//     bits/bytes form a big-endian binary number.
//
// Returns the decoded value, or 0 for an empty/null field, a field without
// digits, or a negative base-256 value (which cannot be a size or mode).
static uint64_t parse_octal(const char* str, size_t len) {
    if (str == nullptr || len == 0) {
        return 0;
    }

    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str);

    if ((bytes[0] & 0x80u) != 0) {
        // Bit 0x40 of the first byte is the sign bit of the base-256 number.
        if ((bytes[0] & 0x40u) != 0) {
            return 0;
        }
        uint64_t value = bytes[0] & 0x3Fu;
        for (size_t i = 1; i < len; ++i) {
            value = (value << 8) | bytes[i];
        }
        return value;
    }

    size_t pos = 0;
    while (pos < len && bytes[pos] == ' ') {
        ++pos;
    }

    uint64_t value = 0;
    for (; pos < len && bytes[pos] >= '0' && bytes[pos] <= '7'; ++pos) {
        value = value * 8 + static_cast<uint64_t>(bytes[pos] - '0');
    }
    return value;
}
// One regular file queued for extraction: where to write it, where its
// payload sits in the archive, and which permissions to create it with.
struct ExtractTask {
    std::string path;     // destination path of the output file
    uint64_t offset = 0;  // position of the payload in the decompressed archive
    uint64_t size = 0;    // payload length in bytes
    uint32_t mode = 0;    // permission bits parsed from the tar header
};
// Extractor state (presumably the type behind the sio_extractor_t handle
// declared in streaming_io.h).
//
// Each callback keeps its own context pointer. Previously both setters wrote
// the single `user_data` slot, so registering an error callback silently
// replaced the context later handed to the progress callback.
struct sio_extractor {
    uint32_t num_threads = 0;
    sio_progress_cb progress_cb = nullptr;
    sio_error_cb error_cb = nullptr;
    void* user_data = nullptr;        // context passed to progress_cb
    void* error_user_data = nullptr;  // context reserved for error_cb
    std::string last_error;           // guarded by error_mutex
    std::mutex error_mutex;
    uint64_t bytes_extracted = 0;     // guarded by stats_mutex
    uint64_t bytes_written = 0;       // guarded by stats_mutex
    std::mutex stats_mutex;
};

// Allocate a new extractor.
//
// num_threads: desired worker count. 0 means "auto": use
//   std::thread::hardware_concurrency(), falling back to 4 when that is
//   unknown and capping at 8. An explicit non-zero value is stored as given.
//
// Returns a handle that must be released with sio_destroy_extractor().
sio_extractor_t* sio_create_extractor(uint32_t num_threads) {
    auto* ex = new sio_extractor_t;
    if (num_threads == 0) {
        num_threads = std::thread::hardware_concurrency();
        if (num_threads == 0) num_threads = 4;
        if (num_threads > 8) num_threads = 8;
    }
    ex->num_threads = num_threads;
    return ex;
}

// Release an extractor created by sio_create_extractor(). Null is accepted.
void sio_destroy_extractor(sio_extractor_t* ex) {
    delete ex;
}

// Register the progress callback and the context pointer it receives.
// No-op when `ex` is null.
void sio_set_progress_cb(sio_extractor_t* ex, sio_progress_cb cb, void* user_data) {
    if (ex) {
        ex->progress_cb = cb;
        ex->user_data = user_data;
    }
}

// Register the error callback and its context pointer. The context is kept
// separately from the progress callback's so the two registrations do not
// overwrite each other. No-op when `ex` is null.
void sio_set_error_cb(sio_extractor_t* ex, sio_error_cb cb, void* user_data) {
    if (ex) {
        ex->error_cb = cb;
        ex->error_user_data = user_data;
    }
}
// Record `msg` as the extractor's most recent error, under error_mutex.
// Does nothing when `ex` is null.
// NOTE(review): the registered error_cb is not invoked here (nor anywhere in
// the visible code) - confirm whether that is intended.
static void set_error(sio_extractor_t* ex, const char* msg) {
    if (!ex) {
        return;
    }
    std::lock_guard<std::mutex> guard(ex->error_mutex);
    ex->last_error = msg;
}
// Extract a single file from gzipped tar
static int extract_file(const char* archive_path, uint64_t offset, uint64_t size,
const char* dst_path, uint32_t mode, sio_extractor_t* ex) {
// Open archive
gzFile gz = gzopen(archive_path, "rb");
if (!gz) {
set_error(ex, "Failed to open archive");
return -1;
}
// Seek to offset (gzseek is slow but necessary)
if (gzseek(gz, offset, SEEK_SET) < 0) {
gzclose(gz);
set_error(ex, "Failed to seek in archive");
return -1;
}
// Ensure destination directory exists
fs::path dst(dst_path);
fs::create_directories(dst.parent_path());
// Open output file
int fd = open(dst_path, O_WRONLY | O_CREAT | O_TRUNC, mode);
if (fd < 0) {
gzclose(gz);
set_error(ex, "Failed to create output file");
return -1;
}
// Extract data
std::vector<uint8_t> buffer(64 * 1024); // 64KB buffer
uint64_t remaining = size;
uint64_t extracted = 0;
while (remaining > 0) {
size_t to_read = std::min<size_t>(buffer.size(), remaining);
int n = gzread(gz, buffer.data(), to_read);
if (n <= 0) {
close(fd);
gzclose(gz);
set_error(ex, "Failed to read from archive");
return -1;
}
ssize_t written = write(fd, buffer.data(), n);
if (written != n) {
close(fd);
gzclose(gz);
set_error(ex, "Failed to write output file");
return -1;
}
remaining -= n;
extracted += n;
// Update stats
if (ex) {
std::lock_guard<std::mutex> lock(ex->stats_mutex);
ex->bytes_extracted += n;
ex->bytes_written += written;
}
// Progress callback
if (ex && ex->progress_cb) {
ex->progress_cb(dst_path, extracted, size, ex->user_data);
}
}
close(fd);
gzclose(gz);
return 0;
}
int sio_extract_tar_gz(sio_extractor_t* ex, const char* archive_path, const char* dst_dir) {
if (!ex || !archive_path || !dst_dir) return -1;
// Open archive
gzFile gz = gzopen(archive_path, "rb");
if (!gz) {
set_error(ex, "Failed to open archive");
return -1;
}
// Read and parse tar headers
std::vector<ExtractTask> tasks;
uint64_t offset = 0;
while (true) {
TarHeader header;
int n = gzread(gz, &header, sizeof(TarHeader));
if (n != sizeof(TarHeader)) {
break; // End of archive or error
}
// Check for empty block (end of archive)
bool empty = true;
for (size_t i = 0; i < sizeof(TarHeader); ++i) {
if (reinterpret_cast<uint8_t*>(&header)[i] != 0) {
empty = false;
break;
}
}
if (empty) {
break;
}
// Parse header
uint64_t file_size = parse_octal(header.size, 12);
uint32_t mode = parse_octal(header.mode, 8);
// Build full path
std::string path;
if (header.prefix[0]) {
path = std::string(header.prefix) + "/" + header.name;
} else {
path = header.name;
}
// Skip directories and other special files for now
if (header.typeflag == '0' || header.typeflag == '\0') {
ExtractTask task;
task.path = (fs::path(dst_dir) / path).string();
task.offset = offset + sizeof(TarHeader);
task.size = file_size;
task.mode = mode;
tasks.push_back(task);
}
// Skip to next header (file size rounded up to 512 bytes)
uint64_t skip_size = (file_size + 511) & ~511;
if (gzseek(gz, skip_size, SEEK_CUR) < 0) {
break;
}
offset += sizeof(TarHeader) + skip_size;
}
gzclose(gz);
if (tasks.empty()) {
set_error(ex, "No files found in archive");
return -1;
}
// Extract files (single-threaded for now - parallel extraction needs block independence)
for (const auto& task : tasks) {
if (extract_file(archive_path, task.offset, task.size,
task.path.c_str(), task.mode, ex) != 0) {
return -1;
}
}
return 0;
}
// Return the most recently recorded error message, or nullptr when `ex` is
// null or no error has been recorded. The pointer refers to the extractor's
// own last_error string; the lock is released before the function returns.
const char* sio_last_error(sio_extractor_t* ex) {
    if (ex == nullptr) {
        return nullptr;
    }
    std::lock_guard<std::mutex> guard(ex->error_mutex);
    if (ex->last_error.empty()) {
        return nullptr;
    }
    return ex->last_error.c_str();
}
// Read the extractor's bytes_extracted counter under stats_mutex.
// Returns 0 for a null handle.
uint64_t sio_get_bytes_extracted(sio_extractor_t* ex) {
    uint64_t total = 0;
    if (ex != nullptr) {
        std::lock_guard<std::mutex> guard(ex->stats_mutex);
        total = ex->bytes_extracted;
    }
    return total;
}
// Read the extractor's bytes_written counter under stats_mutex.
// Returns 0 for a null handle.
uint64_t sio_get_bytes_written(sio_extractor_t* ex) {
    uint64_t total = 0;
    if (ex != nullptr) {
        std::lock_guard<std::mutex> guard(ex->stats_mutex);
        total = ex->bytes_written;
    }
    return total;
}