{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/hep/davide/miniconda3/envs/root_env/lib/python2.7/site-packages/root_numpy/_tree.py:5: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility\n", " from . import _librootnumpy\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import os\n", "import pickle\n", "import math\n", "\n", "from sklearn.metrics import accuracy_score, roc_auc_score\n", "\n", "from xgboost import XGBClassifier\n", "from tools.data_processing import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# IMPORTING THE DATASET" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Run configuration shared by the cells below\n", "l_index = 1  # index into l_flv (1 -> 'mu'); also passed to load_datasets and return_branches_BDT\n", "mother_ID=['Dplus','Ds']  # NOTE(review): not referenced by any other cell in this notebook\n", "l_flv = ['e','mu']  # labels selected by l_index; used to build the BDT output path\n", "PATH='/disk/lhcb_data/davide/Rphipi_new/'  # passed to load_datasets; BDT outputs are written under PATH+'BDT/'\n", "n_cats = 6  # NOTE(review): not referenced by any other cell in this notebook" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Bkg data amounts to 9660 while signal MC amounts to 5521 Ds and 9109 Dplus samples\n" ] } ], "source": [ "# Load the two signal MC dictionaries and the background data dictionary, then pass all three through norm_chi2\n", "MC_Dplus_sig_dict, MC_Ds_sig_dict, data_bkg_dict = load_datasets(l_index, PATH)\n", "MC_Dplus_sig_dict, MC_Ds_sig_dict, data_bkg_dict = norm_chi2(MC_Dplus_sig_dict, MC_Ds_sig_dict, data_bkg_dict)\n", "\n", "# Number of entries in each sample, read from the length of its *_ConsD_M entry\n", "# (the background dictionary is keyed with the Ds_ prefix)\n", "m_plus=MC_Dplus_sig_dict[\"Dplus_ConsD_M\"].shape[0]\n", "m_s=MC_Ds_sig_dict[\"Ds_ConsD_M\"].shape[0]\n", "n=data_bkg_dict[\"Ds_ConsD_M\"].shape[0]\n", "\n", "# Report the sample sizes\n", "\n", "print('Bkg data amounts to {0} while signal MC amounts to {1} Ds and {2} Dplus samples'.format(n,m_s,m_plus))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFNhJREFUeJzt3X+s3fV93/Hnq3ZxtB9xAjhbZyDXESSaWdr88Mg6JW00lNSka5wssJhNmrcgoWxFaxtVm1FVRFgrhXQbWxekDg00QtdBRrbpbjjy2GhSbSKOLwkJcZibi8OGC0pNjNzRlhAn7/1xvi6Hw72+33vvuefc68/zIR35+/18P9973ufcr1/f7/18v+d7UlVIktrwQ9MuQJI0OYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGbp13AqAsvvLBmZmamXYYkbSiPPPLIs1W1bal+6y70Z2ZmmJubm3YZkrShJPk/ffo5vCNJDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9LUuzOx/YNolSE0w9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGvrSR3bx12hVogzH0NXXegkGanF6hn2R3kqNJ5pPsX2D5liT3dcsPJZkZWvajSR5OciTJY0leNb7yJUnLsWToJ9kE3A5cBewErk2yc6TbdcBzVXUpcBtwa7fuZuA3gY9W1eXAu4Hvja16nVNm9j/gUb+0xvoc6V8BzFfVsap6EbgX2DPSZw9wdzd9P3BlkgDvBb5WVV8FqKrvVNX3x1O6JGm5+oT+duCpofnjXduCfarqNHAKuAB4I1BJDib5cpJ/tNATJLk+yVySuRMnTiz3NUiSeuoT+lmgrXr22Qy8E/jb3b8fTHLlKzpW3VFVu6pq17Zt23qUJElaiT6hfxy4eGj+IuDpxfp04/hbgZNd+xeq6tmq+iPgAPC21RYtSVqZPqF/GLgsyY4k5wF7gdmRPrPAvm76auChqirgIPCjSf5UtzP4SeAb4yldkrRcm5fqUFWnk9zAIMA3AXdV1ZEktwBzVTUL3Anck2SewRH+3m7d55L8cwY7jgIOVJWXZ0jSlCwZ+gBVdYDB0Mxw201D0y8A1yyy7m8yuGxTkjRlfiJXkhpi6EtSQwx9SWpIrzF9SeuId9bUKnikL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0NVV+Ebo0WYa+JDXE0Jekhhj6ktQQQ1/rjuP80tox9KWNzlstaxkMfUlqiKEvSQ0x9CWpIYa+JDWkV+gn2Z3kaJL5JPsXWL4lyX3d8kNJZrr2mSR/nOTR7vEb4y1faownbbVKS34xepJNwO3Ae4DjwOEks1X1jaFu1wHPVdWlSfYCtwIf7pY9UVVvGXPdkqQV6HOkfwUwX1XHqupF4F5gz0ifPcDd3fT9wJVJMr4yJUnj0Cf0twNPDc0f79oW7FNVp4FTwAXdsh1JvpLkC0netdATJLk+yVySuRMnTizrBUiS+usT+gsdsVfPPs8Al1TVW4GPAb+V5NWv6Fh1R1Xtqqpd27Zt61GSJGkl+oT+ceDiofmLgKcX65NkM7AVOFlV362q7wBU1SPAE8AbV1u0JGll+oT+YeCyJDuSnAfsBWZH+swC+7rpq4GHqqqSbOtOBJPkDcBlwLHxlC5JWq4lr96pqtNJbgAOApuAu6rqSJJbgLmqmgXuBO5JMg+cZLBjAPgJ4JYkp4HvAx+tqpNr8UIkSUtbMvQBquoAcGCk7aah6ReAaxZY77PAZ1dZoyRpTPxEriQ1xNCXpIYY+pLUEENfkhpi6Gtq/FpEafIMfUlqiKEvSQ0x9LUuOfQjrQ1DX5IaYuhLUkMMfUlqiKEvbRR+P67GwNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jak
hhr4kNcTQl6SGGPqS1JBeoZ9kd5KjSeaT7F9g+ZYk93XLDyWZGVl+SZLnk/zieMqWJK3EkqGfZBNwO3AVsBO4NsnOkW7XAc9V1aXAbcCtI8tvAz63+nIlSavR50j/CmC+qo5V1YvAvcCekT57gLu76fuBK5MEIMkHgGPAkfGULElaqT6hvx14amj+eNe2YJ+qOg2cAi5I8qeBfwx8fPWlSpJWq0/oZ4G26tnn48BtVfX8WZ8guT7JXJK5EydO9ChJkrQSm3v0OQ5cPDR/EfD0In2OJ9kMbAVOAu8Ark7ySeA1wA+SvFBVnxpeuaruAO4A2LVr1+gORZI0Jn1C/zBwWZIdwO8Be4G/NdJnFtgHPAxcDTxUVQW860yHJDcDz48GviRpcpYc3unG6G8ADgKPA5+pqiNJbkny/q7bnQzG8OeBjwGvuKxT0hryW7XUU58jfarqAHBgpO2moekXgGuW+Bk3r6A+SX2dCf6bT023Dq1rfiJXkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+pmJm/wPTLkFqkqGvdcsdgzR+hr4kNcTQl6SGGPrSRuDXIWpMDH1JaoihL0kNMfQlqSGGviQ1xNCXpIb0Cv0ku5McTTKfZP8Cy7ckua9bfijJTNd+RZJHu8dXk3xwvOVLkpZjydBPsgm4HbgK2Alcm2TnSLfrgOeq6lLgNuDWrv3rwK6qeguwG/jXSTaPq3hJ0vL0OdK/ApivqmNV9SJwL7BnpM8e4O5u+n7gyiSpqj+qqtNd+6uAGkfRkqSV6RP624GnhuaPd20L9ulC/hRwAUCSdyQ5AjwGfHRoJyBJmrA+oZ8F2kaP2BftU1WHqupy4C8DNyZ51SueILk+yVySuRMnTvQoSZK0En1C/zhw8dD8RcDTi/Xpxuy3AieHO1TV48AfAn9p9Amq6o6q2lVVu7Zt29a/ep3zvNOmNF59Qv8wcFmSHUnOA/YCsyN9ZoF93fTVwENVVd06mwGSvB54E/DkWCqXJC3bklfSVNXpJDcAB4FNwF1VdSTJLcBcVc0CdwL3JJlncIS/t1v9ncD+JN8DfgD8g6p6di1eiCRpab0un6yqA8CBkbabhqZfAK5ZYL17gHtWWaMkaUz8RK50rvE2zDoLQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhrit1hporxrpjRdHulLUkM80pfWM++jozHzSF+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQ3qFfpLdSY4mmU+yf4HlW5Lc1y0/lGSma39PkkeSPNb9+9fGW76kBXn7Bi1iydBPsgm4HbgK2Alcm2TnSLfrgOeq6lLgNuDWrv1Z4Geq6s3APuCecRUuSVq+Pkf6VwDzVXWsql4E7gX2jPTZA9zdTd8PXJkkVfWVqnq6az8CvCrJlnEULp3zPFrXGugT+tuBp4bmj3dtC/apqtPAKeCCkT4fAr5SVd9dWamSpNXqc2vlLNBWy+mT5HIGQz7vXfAJkuuB6wEuueSSHiVJklaiz5H+ceDiofmLgKcX65NkM7AVONnNXwT8J+DvVNUTCz1BVd1RVbuqate2bduW9wokSb31Cf3DwGVJdiQ5D9gLzI70mWVwohbgauChqqokrwEeAG6sqv81rqLVlpn9D/g1i9KYLBn63Rj9DcBB4HHgM1V1JMktSd7fdbsTuCDJPPAx4MxlnTcAlwK/nOTR7vG6sb8KSVIvvb4usaoOAAdG2m4amn4BuGaB9X4F+JVV1ihJGhM/kStJDTH0Jakhhr4kNcTQ18R4BY40fYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL52r/BIWLcDQl6SGGPraMPxwl7R6hr4kNcTQl6SGGPqaCIdmlsmTsFojhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkM
MfUlqiKEvSQ3pFfpJdic5mmQ+yf4Flm9Jcl+3/FCSma79giS/neT5JJ8ab+naKLxGX1o/lgz9JJuA24GrgJ3AtUl2jnS7Dniuqi4FbgNu7dpfAH4Z+MWxVSxJWrE+R/pXAPNVdayqXgTuBfaM9NkD3N1N3w9cmSRV9YdV9T8ZhL8kacr6hP524Kmh+eNd24J9quo0cAq4YBwFSsMcKpJWp0/oZ4G2WkGfxZ8guT7JXJK5EydO9F1NOjeN87473sNHI/qE/nHg4qH5i4CnF+uTZDOwFTjZt4iquqOqdlXVrm3btvVdTZK0TH1C/zBwWZIdSc4D9gKzI31mgX3d9NXAQ1XV+0hfkjQZm5fqUFWnk9wAHAQ2AXdV1ZEktwBzVTUL3Anck2SewRH+3jPrJ3kSeDVwXpIPAO+tqm+M/6VIkpayZOgDVNUB4MBI201D0y8A1yyy7swq6pMkjVGv0NfkvPnuNy+r/2P7HlujSsbDq22k9cXQX2PLDfFJ/Pz1vqOQtHa89442HP96kFbO0Jekhhj6ktQQx/THZK3H7sdpsVrHPdbvMMw6c/NWuPnUtKvQlHmkL0kN8UhfWsCk/hp6hbW4V47339EQQ19ahjXbGRjMmhCHdySpIR7pSxqrqQ2NqRdDf5k20lU657KZ/Q/w5Cd+etplNG25/xfcGawPhr42LIP/3ODOYLIMfaklG+ha/XPt5oPrhSdyJakhHulrrPwUrrS+Gfra0M7sZFY6tj/1E/MbaLhF5wZDX5o2P5g1Fp4D6McxfY2NQzsbhDuZpnmkv4ip/9mvc5/hqynwSF/nhA3zV8bNWw17TZVH+tI0rIfgH66hwZPJrX4orPnQdxjnJav5T7BhjrQ1sB52OpqKXqGfZDfwL4FNwL+pqk+MLN8CfBp4O/Ad4MNV9WS37EbgOuD7wD+sqoNjq15TM3yp5HoJ/OE61t3tGQxZrRNLhn6STcDtwHuA48DhJLNV9Y2hbtcBz1XVpUn2ArcCH06yE9gLXA78BeC/J3ljVX1/3C9EkzMcrusl8EeN1jW1ncBGug7/zI5po9S7Rs721/+5MPTT50j/CmC+qo4BJLkX2AMMh/4e4OZu+n7gU0nStd9bVd8FvpVkvvt5D4+n/P4cxlm59RrsG8ZGOMp/2fi+4b+Yc+E8QJ/Q3w48NTR/HHjHYn2q6nSSU8AFXfsXR9bdvuJqezDctZA1P/LfCMG+XGd7Te4QXmYj7Qz6hH4WaKueffqsS5Lrgeu72eeTHO1R11q7EHh22kUsYsK1/fW+HTfMe5Zbx/vDF9rQl2G9vm+L1/XxVb7i1Vuv7xkM1Za/O9H36fV9OvUJ/ePAxUPzFwFPL9LneJLNwFbgZM91qao7gDv6FDwpSeaqate061jIeq1tvdYF1rYS67UusLbV6PPhrMPAZUl2JDmPwYnZ2ZE+s8C+bvpq4KGqqq59b5ItSXYAlwFfGk/pkqTlWvJIvxujvwE4yOCSzbuq6kiSW4C5qpoF7gTu6U7UnmSwY6Dr9xkGJ31PAz/rlTuSND29rtOvqgPAgZG2m4amXwCuWWTdXwV+dRU1Tsu6Gm4asV5rW691gbWtxHqtC6xtxTIYhZEktcAbrklSQ5oJ/SR3Jfn9JF8farsvyaPd48kkj3btM0n+eGjZbwyt8/YkjyWZT/Lr3YfQ1qK2tyT5Yvf8c0mu6NrTPe98kq8ledvQOvuSfLN77Fvouda4tncnOTX0vt00tM7uJEe7uvevYW0/luTh7nf0X5K8emjZjd3zH03yU2tV23LqmsK2dnGS307yeJIjSX6uaz8/yYPdtvNgktd27RPZ3lZQ18S2tbPUdk03/4Mku0bWmci2tiJV1cQD+AngbcDXF1n+z4CbuumZs/T7EvDjDC7N/hxw1VrUBvy3Mz8beB/w+aHpz3XP/1eAQ137+cCx7t/XdtOvnXBt7wb+6wI/YxPwBPAG4Dzgq8DONartMPCT3fRHgH/STe/snncLsKOrZ9Na1Lb
Muia9rf0I8LZu+s8Cv9u9N58E9nft+4FbJ7m9raCuiW1rZ6ntLwJvAj4P7BrqP7FtbSWPZo70q+p3GFxZ9ArdEdTfBP792X5Gkh8BXl1VD9fgt/tp4ANrVFsBZ45St/LS5xv2AJ+ugS8Cr+nq+ingwao6WVXPAQ8Cuydc22L+5FYeVfUicOZWHmtR25uA3+mmHwQ+1E3/yS1BqupbwJlbgoy9tmXWtaA13Naeqaovd9P/D3icwafk9wB3d93uHnquiWxvK6hrMWvx+1ywtqp6vKoW+iDpxLa1lWgm9JfwLuDbVfXNobYdSb6S5AtJ3tW1bWfwgbMz1vK2Ej8P/FqSp4B/Ctw4VMPobTG2n6V9krUB/HiSryb5XJLLl6h5LXwdeH83fQ0vfThw2u/bYnXBlLa1JDPAW4FDwJ+rqmdgEHLA64bqmOj71rMumMK2NlLbYqa9rZ2VoT9wLS8/yn8GuKSq3gp8DPitbgy2120lxuTvA79QVRcDv8DgsxCcpYb1UNuXgddX1Y8B/wr4z137JGv7CPCzSR5h8Kf4i0vUMKnaFqtrKttakj8DfBb4+ar6g7N1XaSONalvGXVNfFtbr+/ZcjUf+hncNuJvAPedaev+LPtON/0Ig3G4NzLYM180tPqCt5UYk33Af+ym/wODPw1h8Vtb9LrlxVrWVlV/UFXPd9MHgB9OcuEka6uq/11V762qtzPYkT/RLZrq+7ZYXdPY1pL8MIPw+ndVdeb3+O1u2ObM0NLvd+0Te9+WU9ekt7VFalvMevg/urhJn0SY5oMFTpoxGIf8wkjbNmBTN/0G4PeA87v5wwxOaJ05ufa+taiNwbjhu7vpK4FHuumf5uUn1r7UtZ8PfIvBSbXXdtPnT7i2P89Ln/24Avi/XZ2bGZzo28FLJ7AuX6PaXtf9+0MMxsE/0s1fzstPrh1jcGJtTWpbRl0T3da6n/Vp4F+MtP8aLz9h+slJbm8rqGti29pitQ0t/zwvP5E70W1t2a9n0k84rQeDo6tngO8x2ONe17X/W+CjI30/BBzpfilfBn5maNkuBuOzTwCfOrPhjbs24J3AI10Nh4C3D22At3fP/9jIxvYRBieN5oG/t1bv21lqu2Hoffsi8FeHfs77GFz18ATwS2tY2891z/O7wCeGfz/AL3XPf5ShK2HGXdty6prCtvZOBkMKXwMe7R7vY3Ar9P8BfLP798yOZyLb2wrqmti2dpbaPtj9fr8LfBs4OOltbSUPP5ErSQ1pfkxfklpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1JD/DyPCog5LHFvWAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "\n", "plt.hist(MC_Dplus_sig_dict[\"Dplus_ConsD_M\"],bins=50, density=True);\n", "plt.hist(MC_Ds_sig_dict[\"Ds_ConsD_M\"],bins=50, density=True);\n", "plt.hist(data_bkg_dict[\"Ds_ConsD_M\"],bins=50, density=True);" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#Convert data dictionaries to arrays for XG_BOOST\n", "features=return_branches_BDT(mother_index=0, l_index=l_index)\n", "MC_Dplus_sig = extract_array_for_BDT(MC_Dplus_sig_dict, features, m_plus)\n", 
"\n",
    "# Ds signal MC: feature list for mother_index=1, then the matching array\n",
    "features = return_branches_BDT(mother_index=1, l_index=l_index)\n",
    "MC_Ds_sig = extract_array_for_BDT(MC_Ds_sig_dict, features, m_s)\n",
    "\n",
    "# Background data: extracted with the mother_index=1 feature list as well\n",
    "features = return_branches_BDT(mother_index=1, l_index=l_index)\n",
    "data_bkg = extract_array_for_BDT(data_bkg_dict, features, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of feature columns, taken from the feature list built last in the previous cell\n",
    "dim = len(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag every sample with its class label (0/1 for bkg/sig)\n",
    "MC_Dplus_sig_labelled = add_labels(MC_Dplus_sig, signal=True)\n",
    "MC_Ds_sig_labelled = add_labels(MC_Ds_sig, signal=True)\n",
    "data_bkg_labelled = add_labels(data_bkg, signal=False)\n",
    "\n",
    "# Stack background, Dplus signal and Ds signal (in that order), then shuffle the rows reproducibly\n",
    "data = np.concatenate(\n",
    "    (data_bkg_labelled, MC_Dplus_sig_labelled, MC_Ds_sig_labelled),\n",
    "    axis=0,\n",
    ")\n",
    "np.random.seed(1)\n",
    "np.random.shuffle(data)\n",
    "\n",
    "# Total number of labelled samples\n",
    "train_size = data.shape[0]\n",
    "\n",
    "# The first `dim` columns are the features; column `dim` holds the label,\n",
    "# which is turned into a column vector and one-hot encoded\n",
    "X = data[:, :dim]\n",
    "Y_labels_hot = to_one_hot(data[:, dim].astype(int).reshape(train_size, 1))\n",
    "Y_labels = Y_labels_hot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training XGBOOST with K-folding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pAUC from XG Boost 0.89248732815\n",
      "pAUC from XG Boost 0.8981822648\n",
      "pAUC from XG Boost 0.891673744237\n",
      "pAUC from XG Boost 0.901611464589\n",
      "pAUC from XG Boost 0.907353115727\n",
      "pAUC from XG Boost 0.895278819119\n",
      "pAUC from XG Boost 0.894574348419\n",
      "pAUC from XG Boost 0.897631671969\n",
      "pAUC from XG Boost 0.891227438988\n",
      "pAUC from XG Boost 0.875042311412\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 864x432 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
"source": [
    "# Train one XGBoost classifier per fold and evaluate it on the held-out subset.\n",
    "test=1      # tag of this training run; selects the output directory\n",
    "k=10        # number of \"equi populated\" subsets (folds)\n",
    "plot=True   # if True, save diagnostic plots for every fold\n",
    "\n",
    "# The output directory and the feature list do not depend on the fold: handle them once.\n",
    "PATH_BDTs=PATH+'BDT/'+l_flv[l_index]+'/test_'+str(test)\n",
    "if not os.path.exists(PATH_BDTs):\n",
    "    os.makedirs(PATH_BDTs)  # unlike os.mkdir, also creates missing parent directories\n",
    "\n",
    "variables_file=PATH_BDTs+'/variables_used.pickle'\n",
    "if not os.path.exists(variables_file):\n",
    "    with open(variables_file, 'wb') as handle:\n",
    "        pickle.dump(features, handle, protocol=2)\n",
    "\n",
    "# Thresholds handed to sel_eff when sampling the ROC curve (also used to draw the diagonal)\n",
    "threshold_range=np.linspace(0.0,1.,num=30)\n",
    "\n",
    "for i in range(k):\n",
    "    X_train, Y_train, X_test, Y_test, X_dict, Y_dict = k_subsets(i, k, X, Y_labels)\n",
    "\n",
    "    PATH_current_BDT=PATH_BDTs+'/XG_'+str(i)\n",
    "    if not os.path.exists(PATH_current_BDT):\n",
    "        os.mkdir(PATH_current_BDT)\n",
    "\n",
    "    # Labels are one-hot encoded; column 1 is used as the binary training target\n",
    "    model = XGBClassifier()\n",
    "    model.fit(X_train, Y_train[:,1])\n",
    "\n",
    "    output_XG = model.predict_proba(X_test)\n",
    "\n",
    "    # Report the AUC for every fold, whether or not the plots are produced\n",
    "    pAUC_XG=roc_auc_score(Y_test,output_XG)\n",
    "    print(\"pAUC from XG Boost {0}\".format(pAUC_XG))\n",
    "\n",
    "    if plot:\n",
    "        # Column-1 scores of the test events whose label column 1 / column 0 is set\n",
    "        true_positives_XG=output_XG[:,1][np.where(Y_test[:,1]==1)]\n",
    "        false_positives_XG=output_XG[:,1][np.where(Y_test[:,0]==1)]\n",
    "\n",
    "        plt.clf()\n",
    "        plt.hist(true_positives_XG,alpha=0.5,bins=80,density=True,label=\"True positives\")\n",
    "        plt.hist(false_positives_XG,alpha=0.5,bins=80,density=True, label=\"False positives\")\n",
    "        plt.legend()\n",
    "        plt.xlabel(\"XGBoost BDT output\", fontsize=15)\n",
    "        plt.ylabel(\"Events (a.u.)\", fontsize=15)\n",
    "        fig=plt.gcf()\n",
    "        fig.set_size_inches(16,8)\n",
    "        plt.savefig(PATH_current_BDT+'/tp_vs_fp_XG.png', format='png', dpi=100)\n",
    "        plt.clf()\n",
    "\n",
    "        # Iterate over the thresholds directly. Under Python 2 a list-comprehension\n",
    "        # variable leaks into the enclosing scope, so it must not be called `k`:\n",
    "        # that would overwrite the fold count passed to k_subsets on later iterations.\n",
    "        sig_eps_vals_XG=[sel_eff(true_positives_XG,threshold) for threshold in threshold_range]\n",
    "        bkg_eps_vals_XG=[sel_eff(false_positives_XG,threshold) for threshold in threshold_range]\n",
    "\n",
    "        plt.plot(threshold_range,threshold_range, 'black', linestyle='dashed')\n",
    "        plt.plot(bkg_eps_vals_XG,sig_eps_vals_XG,'b',label=\"XG Boost ROC Curve\")\n",
    "        plt.xlabel(\"Background selection efficiency\", fontsize=15)\n",
    "        plt.ylabel(\"Signal selection efficiency\", fontsize=15)\n",
    "        plt.text(0.69,0.1,\"\\n XGBoost AUC {0:.4g}\\n\".format(pAUC_XG), bbox=dict(boxstyle=\"round\", facecolor='blue', alpha=0.10), horizontalalignment='center', verticalalignment='center',fontsize=15)\n",
    "        plt.legend()\n",
    "        fig=plt.gcf()\n",
    "        fig.set_size_inches(8,8)\n",
    "        plt.savefig(PATH_current_BDT+'/roc_comparison_'+str(i)+'.png', format='png', dpi=100)\n",
    "        plt.clf()\n",
    "\n",
    "        # One tick label per bar: np.arange(dim) yields len(features) ticks,\n",
    "        # so the full feature list is passed rather than features[:-1]\n",
    "        plt.bar(np.arange(dim),model.feature_importances_)\n",
    "        plt.xticks(np.arange(dim), features, rotation=90, fontsize=12)\n",
    "        fig=plt.gcf()\n",
    "        fig.set_size_inches(12,6)\n",
    "        plt.savefig(PATH_current_BDT+'/significant_features_'+str(i)+'.png', format='png', dpi=100)\n",
    "        plt.close(fig)  # close the figure instead of leaving an empty one open\n",
    "\n",
    "    # Save the trained model; the with-block closes the file even if pickling fails\n",
    "    with open(PATH_current_BDT+\"/XG_\"+str(i)+\".pickle.dat\", \"wb\") as model_file:\n",
    "        pickle.dump(model, model_file, protocol=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}