Newer
Older
Rphipi_new / 3_BDT_train.ipynb
@Davide Lancierini Davide Lancierini on 28 May 2019 16 KB first commit
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hep/davide/miniconda3/envs/root_env/lib/python2.7/site-packages/root_numpy/_tree.py:5: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility\n",
      "  from . import _librootnumpy\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pickle\n",
    "import math\n",
    "\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "\n",
    "from xgboost import XGBClassifier\n",
    "from tools.data_processing import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IMPORTING THE DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analysis configuration\n",
    "l_flv = ['e', 'mu']  # lepton flavours; l_flv[l_index] names the output sub-directory\n",
    "l_index = 1\n",
    "mother_ID = ['Dplus', 'Ds']  # not referenced by the cells below\n",
    "PATH = '/disk/lhcb_data/davide/Rphipi_new/'\n",
    "n_cats = 6  # not referenced by the cells below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bkg data amounts to 9660 while signal MC amounts to 5521 Ds and 9109 Dplus samples\n"
     ]
    }
   ],
   "source": [
    "# Load the Dplus / Ds signal MC and the background data, then pass all three through norm_chi2\n",
    "MC_Dplus_sig_dict, MC_Ds_sig_dict, data_bkg_dict = load_datasets(l_index, PATH)\n",
    "MC_Dplus_sig_dict, MC_Ds_sig_dict, data_bkg_dict = norm_chi2(\n",
    "    MC_Dplus_sig_dict, MC_Ds_sig_dict, data_bkg_dict)\n",
    "\n",
    "# Sample sizes, taken from the length of the ConsD_M branch of each sample:\n",
    "# m_plus -> Dplus signal MC, m_s -> Ds signal MC, n -> background data\n",
    "m_plus = MC_Dplus_sig_dict[\"Dplus_ConsD_M\"].shape[0]\n",
    "m_s = MC_Ds_sig_dict[\"Ds_ConsD_M\"].shape[0]\n",
    "n = data_bkg_dict[\"Ds_ConsD_M\"].shape[0]\n",
    "\n",
    "summary = 'Bkg data amounts to {0} while signal MC amounts to {1} Ds and {2} Dplus samples'\n",
    "print(summary.format(n, m_s, m_plus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFNhJREFUeJzt3X+s3fV93/Hnq3ZxtB9xAjhbZyDXESSaWdr88Mg6JW00lNSka5wssJhNmrcgoWxFaxtVm1FVRFgrhXQbWxekDg00QtdBRrbpbjjy2GhSbSKOLwkJcZibi8OGC0pNjNzRlhAn7/1xvi6Hw72+33vvuefc68/zIR35+/18P9973ufcr1/f7/18v+d7UlVIktrwQ9MuQJI0OYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGbp13AqAsvvLBmZmamXYYkbSiPPPLIs1W1bal+6y70Z2ZmmJubm3YZkrShJPk/ffo5vCNJDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9LUuzOx/YNolSE0w9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGvrSR3bx12hVogzH0NXXegkGanF6hn2R3kqNJ5pPsX2D5liT3dcsPJZkZWvajSR5OciTJY0leNb7yJUnLsWToJ9kE3A5cBewErk2yc6TbdcBzVXUpcBtwa7fuZuA3gY9W1eXAu4Hvja16nVNm9j/gUb+0xvoc6V8BzFfVsap6EbgX2DPSZw9wdzd9P3BlkgDvBb5WVV8FqKrvVNX3x1O6JGm5+oT+duCpofnjXduCfarqNHAKuAB4I1BJDib5cpJ/tNATJLk+yVySuRMnTiz3NUiSeuoT+lmgrXr22Qy8E/jb3b8fTHLlKzpW3VFVu6pq17Zt23qUJElaiT6hfxy4eGj+IuDpxfp04/hbgZNd+xeq6tmq+iPgAPC21RYtSVqZPqF/GLgsyY4k5wF7gdmRPrPAvm76auChqirgIPCjSf5UtzP4SeAb4yldkrRcm5fqUFWnk9zAIMA3AXdV1ZEktwBzVTUL3Anck2SewRH+3m7d55L8cwY7jgIOVJWXZ0jSlCwZ+gBVdYDB0Mxw201D0y8A1yyy7m8yuGxTkjRlfiJXkhpi6EtSQwx9SWpIrzF9SeuId9bUKnikL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0NVV+Ebo0WYa+JDXE0Jekhhj6ktQQQ1/rjuP80tox9KWNzlstaxkMfUlqiKEvSQ0x9CWpIYa+JDWkV+gn2Z3kaJL5JPsXWL4lyX3d8kNJZrr2mSR/nOTR7vEb4y1faownbbVKS34xepJNwO3Ae4DjwOEks1X1jaFu1wHPVdWlSfYCtwIf7pY9UVVvGXPdkqQV6HOkfwUwX1XHqupF4F5gz0ifPcDd3fT9wJVJMr4yJUnj0Cf0twNPDc0f79oW7FNVp4FTwAXdsh1JvpLkC0netdATJLk+yVySuRMnTizrBUiS+usT+gsdsVfPPs8Al1TVW4GPAb+V5NWv6Fh1R1Xtqqpd27Zt61GSJGkl+oT+ceDiofmLgKcX65NkM7AVOFlV362q7wBU1SPAE8AbV1u0JGll+oT+YeCyJDuSnAfsBWZH+swC+7rpq4GHqqqSbOtOBJPkDcBlwLHxlC5JWq4lr96pqtNJbgAOApuAu6rqSJJbgLmqmgXuBO5JMg+cZLBjAPgJ4JYkp4HvAx+tqpNr8UIkSUtbMvQBquoAcGCk7aah6ReAaxZY77PAZ1dZoyRpTPxEriQ1xNCXpIYY+pLUEENfkhpi6Gtq/FpEafIMfUlqiKEvSQ0x9LUuOfQjrQ1DX5IaYuhLUkMMfUlqiKEvbRR+P67GwNCXpIYY+pLUEENfkhpi6EtS
Qwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1JBeoZ9kd5KjSeaT7F9g+ZYk93XLDyWZGVl+SZLnk/zieMqWJK3EkqGfZBNwO3AVsBO4NsnOkW7XAc9V1aXAbcCtI8tvAz63+nIlSavR50j/CmC+qo5V1YvAvcCekT57gLu76fuBK5MEIMkHgGPAkfGULElaqT6hvx14amj+eNe2YJ+qOg2cAi5I8qeBfwx8fPWlSpJWq0/oZ4G26tnn48BtVfX8WZ8guT7JXJK5EydO9ChJkrQSm3v0OQ5cPDR/EfD0In2OJ9kMbAVOAu8Ark7ySeA1wA+SvFBVnxpeuaruAO4A2LVr1+gORZI0Jn1C/zBwWZIdwO8Be4G/NdJnFtgHPAxcDTxUVQW860yHJDcDz48GviRpcpYc3unG6G8ADgKPA5+pqiNJbkny/q7bnQzG8OeBjwGvuKxT0hryW7XUU58jfarqAHBgpO2moekXgGuW+Bk3r6A+SX2dCf6bT023Dq1rfiJXkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+pmJm/wPTLkFqkqGvdcsdgzR+hr4kNcTQl6SGGPrSRuDXIWpMDH1JaoihL0kNMfQlqSGGviQ1xNCXpIb0Cv0ku5McTTKfZP8Cy7ckua9bfijJTNd+RZJHu8dXk3xwvOVLkpZjydBPsgm4HbgK2Alcm2TnSLfrgOeq6lLgNuDWrv3rwK6qeguwG/jXSTaPq3hJ0vL0OdK/ApivqmNV9SJwL7BnpM8e4O5u+n7gyiSpqj+qqtNd+6uAGkfRkqSV6RP624GnhuaPd20L9ulC/hRwAUCSdyQ5AjwGfHRoJyBJmrA+oZ8F2kaP2BftU1WHqupy4C8DNyZ51SueILk+yVySuRMnTvQoSZK0En1C/zhw8dD8RcDTi/Xpxuy3AieHO1TV48AfAn9p9Amq6o6q2lVVu7Zt29a/ep3zvNOmNF59Qv8wcFmSHUnOA/YCsyN9ZoF93fTVwENVVd06mwGSvB54E/DkWCqXJC3bklfSVNXpJDcAB4FNwF1VdSTJLcBcVc0CdwL3JJlncIS/t1v9ncD+JN8DfgD8g6p6di1eiCRpab0un6yqA8CBkbabhqZfAK5ZYL17gHtWWaMkaUz8RK50rvE2zDoLQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhrit1hporxrpjRdHulLUkM80pfWM++jozHzSF+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQ3qFfpLdSY4mmU+yf4HlW5Lc1y0/lGSma39PkkeSPNb9+9fGW76kBXn7Bi1iydBPsgm4HbgK2Alcm2TnSLfrgOeq6lLgNuDWrv1Z4Geq6s3APuCecRUuSVq+Pkf6VwDzVXWsql4E7gX2jPTZA9zdTd8PXJkkVfWVqnq6az8CvCrJlnEULp3zPFrXGugT+tuBp4bmj3dtC/apqtPAKeCCkT4fAr5SVd9dWamSpNXqc2vlLNBWy+mT5HIGQz7vXfAJkuuB6wEuueSSHiVJklaiz5H+ceDiofmLgKcX65NkM7AVONnNXwT8J+DvVNUTCz1BVd1RVbuqate2bduW9wokSb31Cf3DwGVJdiQ5D9gLzI70mWVwohbgauChqqokrwEeAG6sqv81rqLVlpn9D/g1i9KYLBn63Rj9DcBB4HHgM1V1JMktSd7fdbsTuCDJPPAx4MxlnTcAlwK/nOTR7vG6sb8KSVIvvb4usaoOAAdG2m4amn4BuGaB9X4F+JVV1ihJGhM/kStJDTH0Jakhhr4kNcTQ18R4BY40fYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL52r/BIWLcDQl6SGGPraMPxwl7R6hr4kNcTQl6SGGPqaCIdmlsmTsFojhr4kNcTQl6SG
GPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ3pFfpJdic5mmQ+yf4Flm9Jcl+3/FCSma79giS/neT5JJ8ab+naKLxGX1o/lgz9JJuA24GrgJ3AtUl2jnS7Dniuqi4FbgNu7dpfAH4Z+MWxVSxJWrE+R/pXAPNVdayqXgTuBfaM9NkD3N1N3w9cmSRV9YdV9T8ZhL8kacr6hP524Kmh+eNd24J9quo0cAq4YBwFSsMcKpJWp0/oZ4G2WkGfxZ8guT7JXJK5EydO9F1NOjeN87473sNHI/qE/nHg4qH5i4CnF+uTZDOwFTjZt4iquqOqdlXVrm3btvVdTZK0TH1C/zBwWZIdSc4D9gKzI31mgX3d9NXAQ1XV+0hfkjQZm5fqUFWnk9wAHAQ2AXdV1ZEktwBzVTUL3Anck2SewRH+3jPrJ3kSeDVwXpIPAO+tqm+M/6VIkpayZOgDVNUB4MBI201D0y8A1yyy7swq6pMkjVGv0NfkvPnuNy+r/2P7HlujSsbDq22k9cXQX2PLDfFJ/Pz1vqOQtHa89442HP96kFbO0Jekhhj6ktQQx/THZK3H7sdpsVrHPdbvMMw6c/NWuPnUtKvQlHmkL0kN8UhfWsCk/hp6hbW4V47339EQQ19ahjXbGRjMmhCHdySpIR7pSxqrqQ2NqRdDf5k20lU657KZ/Q/w5Cd+etplNG25/xfcGawPhr42LIP/3ODOYLIMfaklG+ha/XPt5oPrhSdyJakhHulrrPwUrrS+Gfra0M7sZFY6tj/1E/MbaLhF5wZDX5o2P5g1Fp4D6McxfY2NQzsbhDuZpnmkv4ip/9mvc5/hqynwSF/nhA3zV8bNWw17TZVH+tI0rIfgH66hwZPJrX4orPnQdxjnJav5T7BhjrQ1sB52OpqKXqGfZDfwL4FNwL+pqk+MLN8CfBp4O/Ad4MNV9WS37EbgOuD7wD+sqoNjq15TM3yp5HoJ/OE61t3tGQxZrRNLhn6STcDtwHuA48DhJLNV9Y2hbtcBz1XVpUn2ArcCH06yE9gLXA78BeC/J3ljVX1/3C9EkzMcrusl8EeN1jW1ncBGug7/zI5po9S7Rs721/+5MPTT50j/CmC+qo4BJLkX2AMMh/4e4OZu+n7gU0nStd9bVd8FvpVkvvt5D4+n/P4cxlm59RrsG8ZGOMp/2fi+4b+Yc+E8QJ/Q3w48NTR/HHjHYn2q6nSSU8AFXfsXR9bdvuJqezDctZA1P/LfCMG+XGd7Te4QXmYj7Qz6hH4WaKueffqsS5Lrgeu72eeTHO1R11q7EHh22kUsYsK1/fW+HTfMe5Zbx/vDF9rQl2G9vm+L1/XxVb7i1Vuv7xkM1Za/O9H36fV9OvUJ/ePAxUPzFwFPL9LneJLNwFbgZM91qao7gDv6FDwpSeaqate061jIeq1tvdYF1rYS67UusLbV6PPhrMPAZUl2JDmPwYnZ2ZE+s8C+bvpq4KGqqq59b5ItSXYAlwFfGk/pkqTlWvJIvxujvwE4yOCSzbuq6kiSW4C5qpoF7gTu6U7UnmSwY6Dr9xkGJ31PAz/rlTuSND29rtOvqgPAgZG2m4amXwCuWWTdXwV+dRU1Tsu6Gm4asV5rW691gbWtxHqtC6xtxTIYhZEktcAbrklSQ5oJ/SR3Jfn9JF8farsvyaPd48kkj3btM0n+eGjZbwyt8/YkjyWZT/Lr3YfQ1qK2tyT5Yvf8c0mu6NrTPe98kq8ledvQOvuSfLN77Fvouda4tncnOTX0vt00tM7uJEe7uvevYW0/luTh7nf0X5K8emjZjd3zH03yU2tV23LqmsK2dnGS307yeJIjSX6uaz8/yYPdtvNgktd27RPZ3lZQ18S2tbPUdk03/4Mku0bWmci2tiJV1cQD+AngbcDXF1n+z4CbuumZs/T7EvDjDC7N/hxw1VrUBvy3Mz8beB/w+aHpz3XP/1eAQ137+cCx7t/XdtOvnXBt7wb+6wI/YxPwBPAG4Dzgq8DONartMPCT3fRHgH/S
Te/snncLsKOrZ9Na1LbMuia9rf0I8LZu+s8Cv9u9N58E9nft+4FbJ7m9raCuiW1rZ6ntLwJvAj4P7BrqP7FtbSWPZo70q+p3GFxZ9ArdEdTfBP792X5Gkh8BXl1VD9fgt/tp4ANrVFsBZ45St/LS5xv2AJ+ugS8Cr+nq+ingwao6WVXPAQ8Cuydc22L+5FYeVfUicOZWHmtR25uA3+mmHwQ+1E3/yS1BqupbwJlbgoy9tmXWtaA13Naeqaovd9P/D3icwafk9wB3d93uHnquiWxvK6hrMWvx+1ywtqp6vKoW+iDpxLa1lWgm9JfwLuDbVfXNobYdSb6S5AtJ3tW1bWfwgbMz1vK2Ej8P/FqSp4B/Ctw4VMPobTG2n6V9krUB/HiSryb5XJLLl6h5LXwdeH83fQ0vfThw2u/bYnXBlLa1JDPAW4FDwJ+rqmdgEHLA64bqmOj71rMumMK2NlLbYqa9rZ2VoT9wLS8/yn8GuKSq3gp8DPitbgy2120lxuTvA79QVRcDv8DgsxCcpYb1UNuXgddX1Y8B/wr4z137JGv7CPCzSR5h8Kf4i0vUMKnaFqtrKttakj8DfBb4+ar6g7N1XaSONalvGXVNfFtbr+/ZcjUf+hncNuJvAPedaev+LPtON/0Ig3G4NzLYM180tPqCt5UYk33Af+ym/wODPw1h8Vtb9LrlxVrWVlV/UFXPd9MHgB9OcuEka6uq/11V762qtzPYkT/RLZrq+7ZYXdPY1pL8MIPw+ndVdeb3+O1u2ObM0NLvd+0Te9+WU9ekt7VFalvMevg/urhJn0SY5oMFTpoxGIf8wkjbNmBTN/0G4PeA87v5wwxOaJ05ufa+taiNwbjhu7vpK4FHuumf5uUn1r7UtZ8PfIvBSbXXdtPnT7i2P89Ln/24Avi/XZ2bGZzo28FLJ7AuX6PaXtf9+0MMxsE/0s1fzstPrh1jcGJtTWpbRl0T3da6n/Vp4F+MtP8aLz9h+slJbm8rqGti29pitQ0t/zwvP5E70W1t2a9n0k84rQeDo6tngO8x2ONe17X/W+CjI30/BBzpfilfBn5maNkuBuOzTwCfOrPhjbs24J3AI10Nh4C3D22At3fP/9jIxvYRBieN5oG/t1bv21lqu2Hoffsi8FeHfs77GFz18ATwS2tY2891z/O7wCeGfz/AL3XPf5ShK2HGXdty6prCtvZOBkMKXwMe7R7vY3Ar9P8BfLP798yOZyLb2wrqmti2dpbaPtj9fr8LfBs4OOltbSUPP5ErSQ1pfkxfklpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1JD/DyPCog5LHFvWAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Overlay the normalised ConsD_M distributions of the three samples.\n",
    "# alpha + labels make the overlapping histograms distinguishable in the saved output.\n",
    "mass_samples = [\n",
    "    (MC_Dplus_sig_dict[\"Dplus_ConsD_M\"], 'Dplus signal MC'),\n",
    "    (MC_Ds_sig_dict[\"Ds_ConsD_M\"], 'Ds signal MC'),\n",
    "    (data_bkg_dict[\"Ds_ConsD_M\"], 'Background data'),\n",
    "]\n",
    "for values, label in mass_samples:\n",
    "    plt.hist(values, bins=50, density=True, alpha=0.5, label=label)\n",
    "plt.xlabel('ConsD_M')\n",
    "plt.ylabel('Normalised entries')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the data dictionaries to arrays for XGBoost.\n",
    "# The Dplus signal MC uses mother_index=0; the Ds signal MC and the background data\n",
    "# both use mother_index=1, so that feature list is computed once and shared.\n",
    "features_Dplus = return_branches_BDT(mother_index=0, l_index=l_index)\n",
    "features_Ds = return_branches_BDT(mother_index=1, l_index=l_index)\n",
    "\n",
    "MC_Dplus_sig = extract_array_for_BDT(MC_Dplus_sig_dict, features_Dplus, m_plus)\n",
    "MC_Ds_sig = extract_array_for_BDT(MC_Ds_sig_dict, features_Ds, m_s)\n",
    "data_bkg = extract_array_for_BDT(data_bkg_dict, features_Ds, n)\n",
    "\n",
    "# Later cells use `features` for dim, the saved variable list and the plot labels;\n",
    "# bind it explicitly instead of relying on whichever list happened to be assigned last.\n",
    "features = features_Ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of input features, i.e. the number of columns of the BDT input matrix X\n",
    "dim = len(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the class label to every sample: 0 for background data, 1 for signal MC\n",
    "MC_Dplus_sig_labelled = add_labels(MC_Dplus_sig, signal=True)\n",
    "MC_Ds_sig_labelled = add_labels(MC_Ds_sig, signal=True)\n",
    "data_bkg_labelled = add_labels(data_bkg, signal=False)\n",
    "\n",
    "# Stack background, Dplus signal and Ds signal (in that order), then shuffle the rows\n",
    "# with a fixed seed so the fold assignment is reproducible\n",
    "data = np.concatenate((data_bkg_labelled, MC_Dplus_sig_labelled, MC_Ds_sig_labelled), axis=0)\n",
    "np.random.seed(1)\n",
    "np.random.shuffle(data)\n",
    "\n",
    "# Total number of labelled rows (all of them are later split into the k folds)\n",
    "train_size = data.shape[0]\n",
    "\n",
    "# Columns 0..dim-1 are the input features; column dim is the 0/1 label,\n",
    "# which is turned into a one-hot encoding\n",
    "X = data[:, :dim]\n",
    "Y_labels_hot = to_one_hot(data[:, dim].astype(int).reshape(train_size, 1))\n",
    "Y_labels = Y_labels_hot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training XGBoost with k-fold cross-validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pAUC from XG Boost 0.89248732815\n",
      "pAUC from XG Boost 0.8981822648\n",
      "pAUC from XG Boost 0.891673744237\n",
      "pAUC from XG Boost 0.901611464589\n",
      "pAUC from XG Boost 0.907353115727\n",
      "pAUC from XG Boost 0.895278819119\n",
      "pAUC from XG Boost 0.894574348419\n",
      "pAUC from XG Boost 0.897631671969\n",
      "pAUC from XG Boost 0.891227438988\n",
      "pAUC from XG Boost 0.875042311412\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 864x432 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Divide the dataset into k equally populated subsets (k-fold cross-validation):\n",
    "# each subset is used once as the test set while the model is trained on the others.\n",
    "test = 1  # index of this training configuration; selects the output directory\n",
    "k = 10  # number of subsets\n",
    "plot = True  # save diagnostic plots for every fold\n",
    "\n",
    "# os.makedirs also creates the intermediate BDT/<flavour>/ directories if missing\n",
    "PATH_BDTs = PATH + 'BDT/' + l_flv[l_index] + '/test_' + str(test)\n",
    "if not os.path.exists(PATH_BDTs):\n",
    "    os.makedirs(PATH_BDTs)\n",
    "\n",
    "# Record the input variables the models in this directory are trained on\n",
    "if not os.path.exists(PATH_BDTs + '/variables_used.pickle'):\n",
    "    with open(PATH_BDTs + '/variables_used.pickle', 'wb') as handle:\n",
    "        pickle.dump(features, handle, protocol=2)\n",
    "\n",
    "# Cuts on the BDT output at which the selection efficiencies are evaluated\n",
    "threshold_range = np.linspace(0.0, 1., num=30)\n",
    "\n",
    "for i in range(k):\n",
    "    X_train, Y_train, X_test, Y_test, X_dict, Y_dict = k_subsets(i, k, X, Y_labels)\n",
    "\n",
    "    PATH_current_BDT = PATH_BDTs + '/XG_' + str(i)\n",
    "    if not os.path.exists(PATH_current_BDT):\n",
    "        os.makedirs(PATH_current_BDT)\n",
    "\n",
    "    # Train on the signal flag (column 1 of the one-hot labels)\n",
    "    model = XGBClassifier()\n",
    "    model.fit(X_train, Y_train[:, 1])\n",
    "\n",
    "    output_XG = model.predict_proba(X_test)\n",
    "\n",
    "    # ROC AUC of the signal probability against the signal flag, reported for every fold\n",
    "    pAUC_XG = roc_auc_score(Y_test[:, 1], output_XG[:, 1])\n",
    "    print(\"pAUC from XG Boost {0}\".format(pAUC_XG))\n",
    "\n",
    "    if plot:\n",
    "        # BDT output for test events that are truly signal vs truly background\n",
    "        plt.clf();\n",
    "        true_positives_XG = output_XG[:, 1][np.where(Y_test[:, 1] == 1)]\n",
    "        false_positives_XG = output_XG[:, 1][np.where(Y_test[:, 0] == 1)]\n",
    "        plt.hist(true_positives_XG, alpha=0.5, bins=80, density=True, label=\"True positives\");\n",
    "        plt.hist(false_positives_XG, alpha=0.5, bins=80, density=True, label=\"False positives\");\n",
    "        plt.legend()\n",
    "        plt.xlabel(\"XGBoost BDT output\", fontsize='15')\n",
    "        plt.ylabel(\"Events (a.u.)\", fontsize='15')\n",
    "        fig = plt.gcf()\n",
    "        fig.set_size_inches(16, 8)\n",
    "        plt.savefig(PATH_current_BDT + '/tp_vs_fp_XG.png', format='png', dpi=100)\n",
    "        plt.clf();\n",
    "\n",
    "        # ROC curve from an efficiency scan over the thresholds. The comprehension\n",
    "        # variable must NOT be called `k`: under Python 2 it leaks into the enclosing\n",
    "        # scope and would overwrite the number of folds passed to k_subsets.\n",
    "        sig_eps_vals_XG = [sel_eff(true_positives_XG, thr) for thr in threshold_range]\n",
    "        bkg_eps_vals_XG = [sel_eff(false_positives_XG, thr) for thr in threshold_range]\n",
    "\n",
    "        # Dashed diagonal: reference line of a classifier with no discriminating power\n",
    "        plt.plot(threshold_range, threshold_range, 'black', linestyle='dashed')\n",
    "        plt.plot(bkg_eps_vals_XG, sig_eps_vals_XG, 'b', label=\"XG Boost ROC Curve\")\n",
    "        plt.xlabel(\"Background selection efficiency\", fontsize='15')\n",
    "        plt.ylabel(\"Signal selection efficiency\", fontsize='15')\n",
    "        plt.text(0.69,0.1,\"\\n XGBoost AUC {0:.4g}\\n\".format(pAUC_XG), bbox=dict(boxstyle=\"round\", facecolor='blue', alpha=0.10), horizontalalignment='center', verticalalignment='center',fontsize='15')\n",
    "        plt.legend()\n",
    "        fig = plt.gcf()\n",
    "        fig.set_size_inches(8, 8)\n",
    "        plt.savefig(PATH_current_BDT + '/roc_comparison_' + str(i) + '.png', format='png', dpi=100)\n",
    "        plt.clf();\n",
    "\n",
    "        # Feature importances: one bar and one tick label per column of X (dim == len(features))\n",
    "        plt.bar(np.arange(dim), model.feature_importances_)\n",
    "        plt.xticks(np.arange(dim), features, rotation=90, fontsize=12);\n",
    "        fig = plt.gcf()\n",
    "        fig.set_size_inches(12, 6)\n",
    "        plt.savefig(PATH_current_BDT + '/significant_features_' + str(i) + '.png', format='png', dpi=100)\n",
    "        plt.clf();\n",
    "\n",
    "    # Save the trained XGBoost model; the with-block closes the file handle\n",
    "    with open(PATH_current_BDT + \"/XG_\" + str(i) + \".pickle.dat\", \"wb\") as model_file:\n",
    "        pickle.dump(model, model_file, protocol=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}