#!/usr/bin/env python3
"""Train an XGBoost binary classifier on synthetic data and save the outputs.

The script generates a synthetic two-class dataset, trains a booster with the
hyperparameters given on the command line, measures accuracy on a held-out
split, and writes ``results.json`` plus ``xgboost_model.json`` into the
requested output directory.
"""
import argparse
import json
import logging
from pathlib import Path
import time

import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import xgboost as xgb

logger = logging.getLogger(__name__)

# One seed is used for data generation, the train/test split and the booster.
RANDOM_SEED = 42


def _parse_args(argv=None):
    """Parse command-line options; ``argv=None`` falls back to ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_estimators", type=int, default=100)
    parser.add_argument("--max_depth", type=int, default=6)
    parser.add_argument("--learning_rate", type=float, default=0.1)
    parser.add_argument("--output_dir", type=str, required=True)
    return parser.parse_args(argv)


def _make_data():
    """Generate the synthetic dataset and split it 80/20 into train and test.

    Returns:
        Tuple ``(X, X_train, X_test, y_train, y_test)``; the full feature
        matrix ``X`` is returned as well so its size can be reported.
    """
    X, y = make_classification(
        n_samples=1000, n_features=20, n_classes=2, random_state=RANDOM_SEED
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=RANDOM_SEED
    )
    return X, X_train, X_test, y_train, y_test


def _train(args, X_train, y_train):
    """Train a booster on the training split with the CLI hyperparameters."""
    dtrain = xgb.DMatrix(X_train, label=y_train)
    params = {
        "max_depth": args.max_depth,
        "eta": args.learning_rate,
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "seed": RANDOM_SEED,
    }
    # The third positional argument is the number of boosting rounds.
    return xgb.train(params, dtrain, args.n_estimators)


def _evaluate(model, X_test, y_test):
    """Return the accuracy of ``model`` on the test split as a builtin float."""
    dtest = xgb.DMatrix(X_test, label=y_test)
    y_pred_prob = model.predict(dtest)
    # Predictions are thresholded at 0.5 to obtain class labels.
    y_pred = (y_pred_prob > 0.5).astype(int)
    # Cast so the JSON encoder always receives a plain Python float.
    return float(accuracy_score(y_test, y_pred))


def _save_outputs(output_dir, results, model):
    """Write ``results.json`` and ``xgboost_model.json`` into ``output_dir``.

    The directory (and any missing parents) is created if it does not exist.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(output_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    model.save_model(str(output_dir / "xgboost_model.json"))


def main(argv=None):
    """Run the full train / evaluate / save pipeline.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default, and what the ``__main__`` guard uses) the arguments are
            read from ``sys.argv``.
    """
    args = _parse_args(argv)

    logging.basicConfig(level=logging.INFO)
    logger.info(
        f"Training XGBoost with {args.n_estimators} estimators, depth {args.max_depth}..."
    )

    X, X_train, X_test, y_train, y_test = _make_data()
    model = _train(args, X_train, y_train)
    accuracy = _evaluate(model, X_test, y_test)

    logger.info(f"Training completed. Accuracy: {accuracy:.4f}")

    results = {
        "model_type": "XGBoost",
        "n_estimators": args.n_estimators,
        "max_depth": args.max_depth,
        "learning_rate": args.learning_rate,
        "accuracy": accuracy,
        "n_samples": len(X),
        "n_features": X.shape[1],
    }
    _save_outputs(args.output_dir, results, model)

    logger.info("Results and model saved successfully!")


if __name__ == "__main__":
    main()