diff --git a/DNN.ipynb b/DNN.ipynb index 9f07f7e..63baf44 100644 --- a/DNN.ipynb +++ b/DNN.ipynb @@ -16,29 +16,48 @@ "trunc_normal= tf.truncated_normal_initializer(stddev=1)\n", "normal = tf.random_normal_initializer(stddev=1)\n", "\n", + "from architectures.data_processing import *\n", "from architectures.utils.toolbox import *\n", "from architectures.DNN import *" ] }, { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# IMPORTING THE DATASET" + ] + }, + { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "with open('/disk/lhcb_data/davide/Rphipi/NN_test/MC_for_NN.pickle', 'rb') as f:\n", - " MC_sig_dict=pickle.load(f, encoding='latin1')\n", - "with open('/disk/lhcb_data/davide/Rphipi/NN_test/data_for_NN.pickle', 'rb') as f:\n", - " data_bkg_dict=pickle.load(f, encoding='latin1')" + "l_index=1\n", + "mag_index=1\n", + "Ds_mass= 1968" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Signal MC amounts to 23821 while bkg data amounts to 86051\n" + ] + } + ], "source": [ - "Ds_mass= 1968" + "MC_sig_dict, data_bkg_dict = load_datasets(l_index, mag_index)\n", + "m=MC_sig_dict[\"Ds_ConsD_M\"].shape[0]\n", + "n=data_bkg_dict[\"Ds_ConsD_M\"].shape[0]\n", + "\n", + "print('Signal MC amounts to {0} while bkg data amounts to {1}'.format(m,n))" ] }, { @@ -47,7 +66,28 @@ "metadata": {}, "outputs": [], "source": [ - "m=MC_sig_dict[\"Ds_ConsD_M\"].shape[0]" + "#Normalising the Chi2 vertex fits to the NDoF\n", + "\n", + "MC_sig_dict[\"Ds_ENDVERTEX_CHI2\"]=MC_sig_dict[\"Ds_ENDVERTEX_CHI2\"]/MC_sig_dict[\"Ds_ENDVERTEX_NDOF\"]\n", + "MC_sig_dict[\"Ds_OWNPV_CHI2\"]=MC_sig_dict[\"Ds_OWNPV_CHI2\"]/MC_sig_dict[\"Ds_OWNPV_NDOF\"]\n", + "MC_sig_dict[\"Ds_IPCHI2_OWNPV\"]=MC_sig_dict[\"Ds_IPCHI2_OWNPV\"]/MC_sig_dict[\"Ds_ENDVERTEX_NDOF\"]\n", + "\n", + "del MC_sig_dict[\"Ds_ENDVERTEX_NDOF\"]\n", + "del MC_sig_dict[\"Ds_OWNPV_NDOF\"]\n", + "\n", + "data_bkg_dict[\"Ds_ENDVERTEX_CHI2\"]=data_bkg_dict[\"Ds_ENDVERTEX_CHI2\"]/data_bkg_dict[\"Ds_ENDVERTEX_NDOF\"]\n", + "data_bkg_dict[\"Ds_OWNPV_CHI2\"]=data_bkg_dict[\"Ds_OWNPV_CHI2\"]/data_bkg_dict[\"Ds_OWNPV_NDOF\"]\n", + "data_bkg_dict[\"Ds_IPCHI2_OWNPV\"]=data_bkg_dict[\"Ds_IPCHI2_OWNPV\"]/data_bkg_dict[\"Ds_ENDVERTEX_NDOF\"]\n", + "\n", + "del data_bkg_dict[\"Ds_ENDVERTEX_NDOF\"]\n", + "del data_bkg_dict[\"Ds_OWNPV_NDOF\"]\n", + "\n", + "data_bkg_dict[\"phi_ENDVERTEX_CHI2\"]=data_bkg_dict[\"phi_ENDVERTEX_CHI2\"]/data_bkg_dict[\"phi_ENDVERTEX_NDOF\"]\n", + "data_bkg_dict[\"phi_OWNPV_CHI2\"]=data_bkg_dict[\"phi_OWNPV_CHI2\"]/data_bkg_dict[\"phi_OWNPV_NDOF\"]\n", + "data_bkg_dict[\"phi_IPCHI2_OWNPV\"]=data_bkg_dict[\"phi_IPCHI2_OWNPV\"]/data_bkg_dict[\"phi_ENDVERTEX_NDOF\"]\n", + "\n", + "del data_bkg_dict[\"phi_ENDVERTEX_NDOF\"]\n", + "del data_bkg_dict[\"phi_OWNPV_NDOF\"]" ] }, { @@ -56,19 +96,27 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "MC_sig_dict[\"Ds_OWNPV_FIT\"]=MC_sig_dict[\"Ds_OWNPV_CHI2\"]/MC_sig_dict[\"Ds_OWNPV_NDOF\"]\n", - "MC_sig_dict[\"Ds_ENDV_FIT\"]=MC_sig_dict[\"Ds_ENDVERTEX_CHI2\"]/MC_sig_dict[\"Ds_ENDVERTEX_NDOF\"]\n", - "MC_sig_dict[\"Ds_ConsD_M_norm\"]=MC_sig_dict[\"Ds_ConsD_M\"]/Ds_mass\n", - "#MC_sig_dict[\"Ds_Hlt2Phys_TOS_int\"]=MC_sig_dict[\"Ds_Hlt2Phys_TOS\"].astype(np.float32)\n", - "#MC_sig_dict[\"Ds_Hlt1TrackMVADecision_TOS_int\"]=MC_sig_dict[\"Ds_Hlt1TrackMVADecision_TOS\"].astype(np.float32)\n", - "#MC_sig_dict[\"Ds_Hlt2RareCharmD2PiMuMuOSDecision_TOS_int\"]=MC_sig_dict[\"Ds_Hlt2RareCharmD2PiMuMuOSDecision_TOS\"].astype(np.float32)\n", - "\n", - "del MC_sig_dict[\"Ds_ConsD_M\"], MC_sig_dict[\"Ds_ENDVERTEX_CHI2\"], MC_sig_dict[\"Ds_ENDVERTEX_NDOF\"]\n", - "del MC_sig_dict[\"Ds_OWNPV_CHI2\"], MC_sig_dict[\"Ds_OWNPV_NDOF\"]\n", - "del MC_sig_dict[\"Ds_Hlt2Phys_TOS\"], MC_sig_dict[\"Ds_Hlt1TrackMVADecision_TOS\"], MC_sig_dict[\"Ds_Hlt2RareCharmD2PiMuMuOSDecision_TOS\"]\n", - "del MC_sig_dict[\"Ds_IPCHI2_OWNPV\"] #temporary\n", - "del MC_sig_dict[\"mu_plus_MC15TuneV1_ProbNNmu\"] #temporary\n" + "branches_needed = [\n", + " \"Ds_ENDVERTEX_CHI2\",\n", + " #\"Ds_ENDVERTEX_NDOF\",\n", + " \"Ds_OWNPV_CHI2\",\n", + " #\"Ds_OWNPV_NDOF\",\n", + " \"Ds_IPCHI2_OWNPV\",\n", + " \"Ds_IP_OWNPV\",\n", + " \"Ds_DIRA_OWNPV\",\n", + " #l_flv[l_index]+\"_plus_MC15TuneV1_ProbNN\"+l_flv[l_index],\n", + " #\"Ds_Hlt1TrackMVADecision_TOS\",\n", + " #\"Ds_Hlt2RareCharmD2Pi\"+l_flv[l_index].capitalize()+l_flv[l_index].capitalize()+\"OSDecision_TOS\",\n", + " #\"Ds_Hlt2Phys_TOS\",\n", + " \"phi_ENDVERTEX_CHI2\",\n", + " #\"phi_ENDVERTEX_NDOF\",\n", + " \"phi_OWNPV_CHI2\",\n", + " #\"phi_OWNPV_NDOF\",\n", + " \"phi_IPCHI2_OWNPV\",\n", + " \"phi_IP_OWNPV\",\n", + " \"phi_DIRA_OWNPV\",\n", + " \"Ds_ConsD_M\",\n", + " ] " ] }, { @@ -77,19 +125,9 @@ "metadata": {}, "outputs": [], "source": [ + "#Number of input features\n", "\n", - "data_bkg_dict[\"Ds_OWNPV_FIT\"]=data_bkg_dict[\"Ds_OWNPV_CHI2\"]/data_bkg_dict[\"Ds_OWNPV_NDOF\"]\n", - "data_bkg_dict[\"Ds_ENDV_FIT\"]=data_bkg_dict[\"Ds_ENDVERTEX_CHI2\"]/data_bkg_dict[\"Ds_ENDVERTEX_NDOF\"]\n", - "data_bkg_dict[\"Ds_ConsD_M_norm\"]=data_bkg_dict[\"Ds_ConsD_M\"]/Ds_mass\n", - "#data_bkg_dict[\"Ds_Hlt2Phys_TOS_int\"]=data_bkg_dict[\"Ds_Hlt2Phys_TOS\"].astype(np.float32)\n", - "#data_bkg_dict[\"Ds_Hlt1TrackMVADecision_TOS_int\"]=data_bkg_dict[\"Ds_Hlt1TrackMVADecision_TOS\"].astype(np.float32)\n", - "#data_bkg_dict[\"Ds_Hlt2RareCharmD2PiMuMuOSDecision_TOS_int\"]=data_bkg_dict[\"Ds_Hlt2RareCharmD2PiMuMuOSDecision_TOS\"].astype(np.float32)\n", - "\n", - "del data_bkg_dict[\"Ds_ConsD_M\"], data_bkg_dict[\"Ds_ENDVERTEX_CHI2\"], data_bkg_dict[\"Ds_ENDVERTEX_NDOF\"]\n", - "del data_bkg_dict[\"Ds_OWNPV_CHI2\"], data_bkg_dict[\"Ds_OWNPV_NDOF\"]\n", - "del data_bkg_dict[\"Ds_Hlt2Phys_TOS\"], data_bkg_dict[\"Ds_Hlt1TrackMVADecision_TOS\"], data_bkg_dict[\"Ds_Hlt2RareCharmD2PiMuMuOSDecision_TOS\"]\n", - "del data_bkg_dict[\"Ds_IPCHI2_OWNPV\"] #temporary\n", - "del data_bkg_dict[\"mu_plus_MC15TuneV1_ProbNNmu\"] #temporary\n" + "dim=len(branches_needed)" ] }, { @@ -98,106 +136,61 @@ "metadata": {}, "outputs": [], "source": [ - "dim=len(MC_sig_dict)" + "#Convert data dictionaries to arrays for NN\n", + "\n", + "MC_sig = extract_array(MC_sig_dict, branches_needed, dim, m)\n", + "data_bkg = extract_array(data_bkg_dict, branches_needed, dim, n)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Ds_DIRA_OWNPV\n", - "Ds_IP_OWNPV\n", - "Ds_OWNPV_FIT\n", - "Ds_ENDV_FIT\n", - "Ds_ConsD_M_norm\n" - ] - }, - { - "data": { - "text/plain": [ - "(23875, 5)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "MC_sig = np.array(\n", - " \n", - " [np.zeros(shape=(dim), dtype=np.float32) for i in range(m)]\n", - " \n", - " )\n", + "#Add 0/1 label for bkg/sig\n", "\n", - "for event in range(m):\n", - " for i, key in enumerate(MC_sig_dict):\n", - " MC_sig[event][i]=MC_sig_dict[key][event]\n", - " \n", - "for key in MC_sig_dict:\n", - " print(key)\n", - "\n", - "MC_sig.shape" + "MC_sig_labelled=add_labels(MC_sig,signal=True)\n", + "data_bkg_labelled=add_labels(data_bkg,signal=False)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Ds_DIRA_OWNPV\n", - "Ds_IP_OWNPV\n", - "Ds_OWNPV_FIT\n", - "Ds_ENDV_FIT\n", - "Ds_ConsD_M_norm\n" - ] - }, - { - "data": { - "text/plain": [ - "(23875, 5)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "data_bkg = np.array(\n", - " \n", - " [np.zeros(shape=(dim), dtype=np.float32) for i in range(m)]\n", - " \n", - " )\n", - "\n", - "for event in range(m):\n", - " for i, key in enumerate(data_bkg_dict):\n", - " data_bkg[event][i]=data_bkg_dict[key][event]\n", - " \n", - "for key in data_bkg_dict:\n", - " print(key)\n", - "data_bkg.shape" + "#SOME CROSS CHECKS\n", + "#MC_sig.shape==data_bkg.shape\n", + "#MC_sig_labelled.shape[1]==dim+1==data_bkg_labelled.shape[1]\n", + "#data_bkg_labelled[:,dim].sum()==0\n", + "#(MC_sig_labelled[:,dim].sum()/m)==1" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(101872, 4000, 4000)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "val_size=1000\n", - "test_size=3000\n", + "#Establish train/val/test sizes\n", "\n", - "train_size=m-val_size-test_size" + "val_size=4000\n", + "test_size=4000\n", + "\n", + "train_size=MC_sig.shape[0]+data_bkg.shape[0]-val_size-test_size\n", + "(train_size, val_size, test_size)" ] }, { @@ -208,7 +201,7 @@ { "data": { "text/plain": [ - "(19875, 1000, 3000)" + "True" ] }, "execution_count": 11, @@ -217,7 +210,15 @@ } ], "source": [ - "(train_size, val_size, test_size)" + "#Merge MC sig and data bkg, shuffle it\n", + "\n", + "data=np.concatenate((MC_sig_labelled,data_bkg_labelled), axis =0)\n", + "np.random.seed(1)\n", + "np.random.shuffle(data)\n", + "\n", + "#Check that nothing is missing\n", + "\n", + "data.shape[0]==train_size+val_size+test_size" ] }, { @@ -226,42 +227,61 @@ "metadata": {}, "outputs": [], "source": [ + "#Strip away the label column and convert it to a one-hot encoding\n", "\n", - "MC_sig_label=np.concatenate((MC_sig,np.ones(shape=(MC_sig.shape[0],1))),axis=1)\n", - "data_bkg_label=np.concatenate((data_bkg,np.zeros(shape=(data_bkg.shape[0],1))),axis=1)\n" + "X=data[:,0:dim]\n", + "Y_labels=data[:,dim].astype(int)\n", + "Y_labels=Y_labels.reshape(train_size+val_size+test_size,1)\n", + "Y_labels_hot = to_one_hot(Y_labels)\n", + "Y_labels=Y_labels_hot\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([0.99996907, 0.03041327, 0.38916221, 0.28149822, 1.00150383,\n", - " 1. ])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ + "#Divide the dataset in train/val/test sets \n", "\n", - "data=np.concatenate((MC_sig_label,data_bkg_label), axis =0)\n", - "data[0]\n" + "X_train_0 = X[0:train_size]\n", + "Y_train = Y_labels[0:train_size]\n", + "\n", + "X_val_0 = X[train_size:train_size+val_size]\n", + "Y_val = Y_labels[train_size:train_size+val_size]\n", + "\n", + "X_test_0 = X[train_size+val_size:train_size+val_size+test_size]\n", + "Y_test = Y_labels[train_size+val_size:train_size+val_size+test_size]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['Ds_ENDVERTEX_CHI2',\n", + " 'Ds_OWNPV_CHI2',\n", + " 'Ds_IPCHI2_OWNPV',\n", + " 'Ds_IP_OWNPV',\n", + " 'Ds_DIRA_OWNPV',\n", + " 'phi_ENDVERTEX_CHI2',\n", + " 'phi_OWNPV_CHI2',\n", + " 'phi_IPCHI2_OWNPV',\n", + " 'phi_IP_OWNPV',\n", + " 'phi_DIRA_OWNPV',\n", + " 'Ds_ConsD_M']" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "np.random.seed(1)\n", - "np.random.shuffle(data)" + "branches_needed" ] }, { @@ -270,92 +290,109 @@ "metadata": {}, "outputs": [], "source": [ + "#Strip out the reconstructed Ds mass\n", "\n", - "X=data[:,0:dim]\n", - "Y_labels=data[:,dim].astype(int)\n" + "X_train = X_train_0[:,0:dim-1]\n", + "X_val = X_val_0[:,0:dim-1]\n", + "X_test = X_test_0[:,0:dim-1]\n", + "dim=X_train.shape[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SETTING UP THE NETWORK" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ - "Y_labels=Y_labels.reshape(Y_labels.shape[0],1)" + "task='TRAIN'\n", + "#task='TEST'\n", + "\n", + "PATH=l_flv[l_index]+'_Mag'+mag_status[mag_index]+'_test_4'" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "temp=np.zeros(shape=(m,2))\n", - "for i in range(m):\n", - " if Y_labels[i]==0:\n", - " temp[i][0]=1\n", - " else:\n", - " temp[i][1]=1\n", - "Y_labels=temp" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "X_train = X[0:train_size]\n", - "Y_train = Y_labels[0:train_size]\n", + "if task =='TRAIN' and os.path.exists(PATH+'/hyper_parameters.pkl'):\n", + " with open(PATH+'/hyper_parameters.pkl', 'rb') as f: \n", + " hyper_dict = pickle.load(f)\n", + " \n", + " m=hyper_dict[\"m\"]\n", + " test_size=hyper_dict[\"test_size\"]\n", + " val_size=hyper_dict[\"val_size\"]\n", + " LEARNING_RATE=hyper_dict[\"LEARNING_RATE\"]\n", + " BETA1=hyper_dict[\"BETA1\"]\n", + " BATCH_SIZE=hyper_dict[\"BATCH_SIZE\"]\n", + " EPOCHS=hyper_dict[\"EPOCHS\"]\n", + " VAL_PERIOD=hyper_dict[\"VAL_PERIOD\"]\n", + " SEED=hyper_dict[\"SEED\"]\n", + " sizes=hyper_dict[\"sizes\"]\n", + " LAMBD=hyper_dict[\"LAMBD\"]\n", + " PATH=hyper_dict[\"PATH\"]\n", "\n", - "X_val = X[train_size:train_size+val_size]\n", - "Y_val = Y_labels[train_size:train_size+val_size]\n", - "\n", - "X_test = X[train_size+val_size:train_size+val_size+test_size]\n", - "Y_test = Y_labels[train_size+val_size:train_size+val_size+test_size]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "# some constants\n", - "\n", - "LEARNING_RATE = 0.0001\n", - "BETA1 = 0.5\n", - "BATCH_SIZE = 64\n", - "EPOCHS = 20000\n", - "SAVE_SAMPLE_PERIOD = 2000\n", - "\n", - "#task='TRAIN'\n", - "task='TEST'\n", - "PATH='test_1'" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "global sizes\n", - "sizes = {\n", + "elif task=='TRAIN' and not os.path.exists(PATH+'/hyper_parameters.pkl'):\n", + " \n", + " \n", + " LEARNING_RATE = 0.001\n", + " BETA1 = 0.5\n", + " BATCH_SIZE = 64\n", + " EPOCHS = 20000\n", + " VAL_PERIOD = 2000\n", + " SEED=1\n", + " LAMBD=1.\n", + " \n", + " sizes = {\n", " 'dense_layers': [\n", - " #(512, 'bn', 0.8, lrelu, tf.glorot_uniform_initializer()),\n", - " #(128, 'bn', 0.5, tf.nn.relu, tf.glorot_uniform_initializer()),\n", - " (16, 'bn', 0.5, tf.nn.relu, tf.glorot_uniform_initializer()),\n", - " (8, 'bn', 0.5, tf.nn.relu, tf.glorot_uniform_initializer()),\n", - " (4, 'bn', 0.5, tf.nn.relu, tf.glorot_uniform_initializer()),\n", + " #(16, 'bn', 0.8, lrelu, tf.glorot_uniform_initializer()),\n", + " #(8, 'bn', 0.5, lrelu, tf.glorot_uniform_initializer()),\n", + " #(16, 'bn',0.8, lrelu, tf.glorot_uniform_initializer()),\n", + " (32, 'bn', 0.8, lrelu, tf.glorot_uniform_initializer()),\n", + " (16, 'bn', 0.8, lrelu, tf.glorot_uniform_initializer()),\n", + " (8, 'bn', 0.8, lrelu, tf.glorot_uniform_initializer()),\n", " ],\n", " 'n_classes':2,\n", - "}" + " }" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "if task == 'TEST' and os.path.exists(PATH+'/hyper_parameters.pkl'):\n", + " with open(PATH+'/hyper_parameters.pkl', 'rb') as f: \n", + " hyper_dict = pickle.load(f)\n", + " #for key, item in hyper_dict.items():\n", + " # print(key+':'+str(item))\n", + " \n", + " m=hyper_dict[\"m\"]\n", + " test_size=hyper_dict[\"test_size\"]\n", + " val_size=hyper_dict[\"val_size\"]\n", + " LEARNING_RATE=hyper_dict[\"LEARNING_RATE\"]\n", + " BETA1=hyper_dict[\"BETA1\"]\n", + " BATCH_SIZE=hyper_dict[\"BATCH_SIZE\"]\n", + " EPOCHS=hyper_dict[\"EPOCHS\"]\n", + " VAL_PERIOD=hyper_dict[\"VAL_PERIOD\"]\n", + " SEED=hyper_dict[\"SEED\"]\n", + " sizes=hyper_dict[\"sizes\"]\n", + " LAMBD=hyper_dict[\"LAMBD\"]\n", + " PATH=hyper_dict[\"PATH\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -363,9 +400,9 @@ " \n", " tf.reset_default_graph()\n", " nn = DNN(dim, sizes,\n", - " lr=LEARNING_RATE, beta1=BETA1,\n", + " lr=LEARNING_RATE, beta1=BETA1, lambd=LAMBD,\n", " batch_size=BATCH_SIZE, epochs=EPOCHS,\n", - " save_sample=SAVE_SAMPLE_PERIOD, path=PATH)\n", + " save_sample=VAL_PERIOD, path=PATH, seed=SEED)\n", " \n", " vars_to_train= tf.trainable_variables()\n", " \n", @@ -404,7 +441,7 @@ " print('Model restored.')\n", " \n", " nn.set_session(sess)\n", - " #nn.test(X_test, Y_test)\n", + " nn.test(X_test, Y_test)\n", " \n", " output = nn.predict(X_test)\n", " \n", @@ -413,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 27, "metadata": { "scrolled": true }, @@ -422,37 +459,70 @@ "name": "stdout", "output_type": "stream", "text": [ - "Propagation\n", - "Input for propagation (?, 5)\n", - "(?, 5)\n", - "(?, 16)\n", - "(?, 8)\n", - "(?, 4)\n", + "Input for propagation (?, 10)\n", "Logits shape (?, 2)\n", - "Propagation\n", - "Input for propagation (?, 5)\n", - "(?, 5)\n", - "(?, 16)\n", - "(?, 8)\n", - "(?, 4)\n", + "Input for propagation (?, 10)\n", "Logits shape (?, 2)\n", "\n", - " Evaluate model on test set...\n", - "INFO:tensorflow:Restoring parameters from test_1/CNN_model.ckpt\n" + " Training...\n", + "\n", + " ****** \n", + "\n", + "Training CNN for 20000 epochs with a total of 101872 samples\n", + "distributed in 1591 batches of size 64\n", + "\n", + "The learning rate set is 0.001\n", + "\n", + " ****** \n", + "\n", + "Evaluating performance on validation/train sets\n", + "At iteration 0, train cost: 0.003296, train accuracy 0.9763\n", + "validation accuracy 0.9898\n", + "Evaluating performance on validation/train sets\n", + "At iteration 2000, train cost: 0.002107, train accuracy 0.9792\n", + "validation accuracy 0.9375\n", + "Evaluating performance on validation/train sets\n", + "At iteration 4000, train cost: 0.001401, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 6000, train cost: 0.003771, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 8000, train cost: 0.001128, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 10000, train cost: 0.0007949, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 12000, train cost: 0.0005857, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 14000, train cost: 0.0006453, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 16000, train cost: 0.0004602, train accuracy 1\n", + "validation accuracy 1\n", + "Evaluating performance on validation/train sets\n", + "At iteration 18000, train cost: 0.0005284, train accuracy 1\n", + "validation accuracy 1\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Restoring parameters from test_1/CNN_model.ckpt\n" - ] + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEWCAYAAACXGLsWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XucV1W9//HXe2YYUFBDIY+CCnqwwvJSE9XpZGVeMDviKUu6nOh2zMpu/jrnh9UxD/08lV2Pj+MxPSfSLCXTLqQYmmnmnUG5CIYOiDKiAqLcYZiZz++PvWbYfPl+5zsgm+8A7+fjMQ/2Xnut/V3fzcz3811r7b2WIgIzM7Oe1NW6AmZm1vc5WJiZWVUOFmZmVpWDhZmZVeVgYWZmVTlYmJlZVQ4WVhOSFks6uUavvVbSkbV4bbPdlYOF7XUiYlBELKp1PQAkhaS/LeC8B0r6jaR1kp6S9KEe8krSdyS9kH4ulaTc8eMlzZS0Pv17fO7YOyXdKWmVpMU7+31Y3+FgYXsUSfW1rkMXSQ01fPnLgTbgYODDwBWSjqmQ91zgLOA44FjgPcCnASQ1Ar8Dfg4MBq4BfpfSAdYBk4F/KeZtWF/hYGE1J6lO0kRJC9M32xskHZg7/itJz6Vvr3fnP/QkXS3pCknTJK0D3pnSLpd0i6Q1kh6UdFSuTPe3+V7kPVXSgvTa/y3pz5I+VeF9XCzpRkk/l7Qa+JikMZLul/SSpGcl/VfXB62ku1PR2alr7JyU/h5Js1KZ+yQdu53XcyDwPuDfImJtRNwDTAX+qUKRCcD3I6I1Ip4Bvg98LB17B9AA/CgiNkXEZYCAkwAi4qGIuBboEy01K46DhfUFXyD7Zvt24FDgRbJvxl1uBUYBrwQeBn5RUv5DwCXAfsA9Ke2DwL+TfRtuSccrKZtX0hDgRuBC4CBgAfB3Vd7LuFTmFameHcCXgSHAW4B3AZ8FiIgTU5njUtfYLyW9nuyb+qfTa14JTJXUP9Xp5hREyv3cnM53NNAREY/n6jUbqNSyOCYdL5f3GGBObD0v0JwezmV7KAcL6ws+DXwtfbPdBFwMnN3VjRMRkyNiTe7YcZIOyJX/XUTcGxGdEbExpf06fettJ/vQPp7KKuV9NzAvIn6djl0GPFflvdwfEb9NddkQETMj4oGIaI+IxWQf/m/vofw/A1dGxIMR0RER1wCbgDena/GeiHhFhZ/3pHMMAlaVnHcVWTAtpzT/KmBQGrfY3nPZHqqWfapmXY4AfiOpM5fWARws6Tmyb/rvB4YCXXmGsOVDbEmZc+Y/1NeTfehVUinvoflzR0RIau35rWxdF0lHAz8AmoB9yf7mZvZQ/ghggqTP59IaU116ay2wf0na/sCaXubfH1ib3u/2nsv2UG5ZWF+wBDi95FvygNR//iGyrp2TgQOAEamMcuWLmjr5WWB41076pj28cvaydbkC+CswKiL2B77K1nUvtQS4pORa7BsR16c63JrGN8r93JrO8TjQIGlU7rzHAfMqvOa8dLxc3nnAsfm7o8gGwSudy/ZQDhbWF/wYuETSEQCShkoal47tR9YN8wLZN/P/2IX1ugV4naSzUpfY54C/2c5z7AesBtZKejXwmZLjzwP5Zz7+BzhP0pvSLa0DJZ0haT+AiDg9jW+U+zk95VkH/BqYlMq/lSzgXluhjj8DLpA0TNKhwP8Brk7H7iJr5X1BUn9J56f0P0H3zQkDgH7Zrgbk7pSyPYiDhfUF/0l2t85tktYADwBvSsd+BjwFPAPMT8d2iYhYQdb9dSlZsBoNNJMFr976ClnraA1ZIPhlyfGLgWvSAPUHIqKZbNziv8gG+lvYcmfS9vgssA+wDLge+ExEzAOQ9LbUvdTlSuD3wFzgUbIgeSVARLSR3XzwUeAl4BPAWSkd4ERgAzANODxt37YD9bU+Tl78yKx3JNUBrcCHI+LOWtfHbFdyy8KsB5JOk/SKdOtq13jDLmvdmPUVDhZmPXsLsBBYAfwDWRfMhtpWyWzXczeUmZlV5ZaFmZlVtcc8lDdkyJAYMWJErathZrZbmTlz5oqIGFot3x4TLEaMGEFzc3Otq2FmtluR9FRv8rkbyszMqnKwMDOzqhwszMysqkKDhaSxaeGYFkkTe8h3dlqQpimXdmEqt0DSaUXW08zMelbYALey5S0vB04hmyJhhqSpETG/JN9+ZIvfPJhLGw2MJ1tg5VDgj5KOjoiOouprZmaVFdmyGAO0RMSiNOnYFLKZL0t9k2yito25tHHAlLSM45Nkk6mNKbCuZmbWgyKDxTC2XgimNaV1k3QCcFhE3MzWqpZN5c+V1Cypefny5Tun1mZmto0ig0W5BV665xZJM3j+kGzu/O0q250QcVVENEVE09ChVZ8pKWvdpnZ+cNsCHnn6xR0qb2a2NygyWLQCh+X2hwNLc/v7Aa8F7pK0mGyN4alpkLta2Z1m4+YOLvtTC3NaS5cZNjOzLkUGixnAKEkj08pZ48kWuAEgIlZFxJCIGBERI8imfT4zLf4yFRifVuYaCYwCHiqikvV1WSOmo9MTKpqZVVLY3VAR0Z6WYJwO1AOTI2KepElAc0RM7aHsPEk3kK2M1g58rqg7oepSsOj07LtmZhUVOjdUREwjW24xn3ZRhbzvKNm/BLiksMol9XLLwsysmr3+Ce7ubii3LMzMKtrrg0Vdall0umVhZlbRXh8sGlLLot3Bwsysor0+WHQPcDtYmJlVtNcHC8jGLTxmYWZWmYMF2R1RHZ21roWZWd/lYAHU1fk5CzOznjhY0NWycLAwM6vEwYJskNvBwsysMgcLsgFud0OZmVXmYIG7oczMqnGwIOuGcsvCzKwyBwvcsjAzq8bBgvRQnp+zMDOryMECP2dhZlaNgwXuhjIzq6bQYCFprKQFklokTSxz/DxJcyXNknSPpNEpfYSkDSl9lqQfF1nPOs8NZWbWo8JWypNUD1wOnAK0AjMkTY2I+bls10XEj1P+M4EfAGPTsYURcXxR9ctrqBMdHQ4WZmaVFNmyGAO0RMSiiGgDpgDj8hkiYnVudyBQk0/sOrllYWbWkyKDxTBgSW6/NaVtRdLnJC0ELgW+kDs0UtIjkv4s6W3lXkDSuZKaJTUvX758hytaXyevZ2Fm1oMig4XKpG3ziRwRl0fEUcD/Bb6ekp8FDo+IE4ALgOsk7V+m7FUR0RQRTUOHDt3hino9CzOznhUZLFqBw3L7w4GlPeSfApwFEBGbIuKFtD0TWAgcXVA9s24otyzMzCoqMljMAEZJGimpERgPTM1nkDQqt3sG8ERKH5oGyJF0JDAKWFRURT2RoJlZzwq7Gyoi2iWdD0wH6oHJETFP0iSgOSKmAudLOhnYDLwITEjFTwQmSWoHOoDzImJlUXX1cxZmZj0rLFgARMQ0YFpJ2kW57S9WKHcTcFORdcurq4NOT/dhZlaRn+DGA9xmZtU4WOABbjOzahws8AC3mVk1DhZ4gNvMrBoHC9JEgg4WZmYVOViQtSzcDWVmVpmDBVBfL9rdsjAzq8jBgtSycLAwM6vIwQI/Z2FmVo2DBdlzFn6C28ysMgcLoL4O3w1lZtYDBwvcDWVmVo2DBV3dUA4WZmaVOFjgloWZWTUOFngiQTOzahwsSBMJOliYmVVUaLCQNFbSAkktkiaWOX6epLmSZkm6R9Lo3LELU7kFkk4rsp7uhjIz61lhwSKtoX05cDowGvhgPhgk10XE6yLieOBS4Aep7GiyNbuPAcYC/921JncR/JyFmVnPimxZjAFaImJRRLQBU4Bx+QwRsTq3OxDo+no/DpgSEZsi4kmgJZ2vEPV1uGVhZtaDItfgHgYsye23Am8qzSTpc8AFQCNwUq7sAyVlh5Upey5wLsDhhx++wxX1ehZmZj0rsmWhMmnbfCJHxOURcRTwf4Gvb2fZqyKiKSKahg4dusMVra/LLoMHuc3MyisyWLQCh+X2hwNLe8g/BThrB8u+LPXpKniacjOz8ooMFjOAUZJGSmokG7Cems8gaVRu9wzgibQ9FRgvqb+kkcAo4KGiKlpXlzVkvACSmVl5hY1ZRES7pPOB6UA9MDki5kmaBDRHxFTgfEknA5uBF4EJqew8STcA84F24HMR0VFUXeuVBQuPW5iZlVfkADcRMQ2YVpJ2UW77iz2UvQS4pLjabVGfWha+I8rMrDw/wU32nAV4gNvMrBIHC3ItCwcLM7OyHCzYMsDtbigzs/IcLNgywO0pP8zMynOwYMtzFm5ZmJmV52CBB7jNzKpxsMAD3GZm1ThY4OcszMyqcbDA3VBmZtU4WOCWhZlZNQ4WbAkW7R0OFmZm5ThYkHvOwi0LM7OyHCzw3VBmZtU4WOD1LMzMqnGwIL+eRY0rYmbWRzlYAHVd0324G8rMrKxCg4WksZIWSGqRNLHM8QskzZc0R9Idko7IHeuQNCv9TC0tuzN5gNvMrGeFrZQnqR64HDgFaAVmSJoaEfNz2R4BmiJivaTPAJcC56RjGyLi+KLql+cBbjOznhXZshgDtETEoohoA6YA4/IZIuLOiFifdh8AhhdYn4q8noWZWc+KDBbDgCW5/daUVskngVtz+wMkNUt6QNJZ5QpIOjflaV6+fPkOV7Te032YmfWosG4oQGXSyn4aS/oI0AS8PZd8eEQslXQk8CdJcyNi4VYni7gKuAqgqalphz/p3Q1lZtazIlsWrcBhuf3hwNLSTJJOBr4GnBkRm7rSI2Jp+ncRcBdwQlEVrfMAt5lZj4oMFjOAUZJGSmoExgNb3dUk6QTgSrJAsSyXPlhS/7Q9BHgrkB8Y36m2tCyKegUzs91bYd1QEdEu6XxgOlAPTI6IeZImAc0RMRX4LjAI+JWyb/dPR8SZwGuAKyV1kgW0b5fcRbVTeVlVM7OeFTlmQURMA6aVpF2U2z65Qrn7gNcVWbe8+vRUXkenmxZmZuX4CW6gwd1QZmY9crAgP2bhaGFmVo6DBVtaFu2+ddbMrCwHC7xSnplZNQ4WQEO6HcotCzOz8hwsyA9we8zCzKwcBwu2dENtdjeUmVlZDhbkWxYOFmZm5ThYkBvgdrAwMyvLwQKQREOdPGZhZlaBg0VSXyffOmtmVoGDRdJQJ3dDmZlV0KtgIen9vUnbnTXU13mA28ysgt62LC7sZdpuq6FObPZMgmZmZfU4Rbmk04F3A8MkXZY7tD/QXmTFdrWGeo9ZmJlVUm09i6VAM3AmMDOXvgb4clGVqoWGujo2+24oM7OyeuyGiojZEXEN8LcRcU3angq0RMSL1U4uaaykBZJaJE0sc/wCSfMlzZF0h6QjcscmSHoi/UzYgfe2Xfq5ZWFmVlFvxyxul7S/pAOB2cBPJf2gpwKS6oHLgdOB0cAHJY0uyfYI0BQRxwI3ApemsgcC3wDeBIwBviFpcC/rukMa6utod8vCzKys3gaLAyJiNfBe4KcR8Qag7JKoOWPIWiCLIqINmAKMy2eIiDsjYn3afQAYnrZPA26PiJWpBXM7MLaXdd0h2QC3WxZmZuX0Nlg0SDoE+ABwcy/LDAOW5PZbU1olnwRu3Z6yks6V1Cypefny5b2sVnn96uto991QZmZl9TZYTAKmAwsjYoakI4EnqpRRmbSyX90lfQRoAr67PWUj4qqIaIqIpqFDh1apTs8a6v1QnplZJdXuhgIgIn4F/Cq3vwh4X5VircBhuf3hZHdXbUXSycDXgLdHxKZc2XeUlL2rN3XdUf3q6vychZlZBb19gnu4pN9IWibpeUk3SRpepdgMYJSkkZIagfFkd1Llz3sCcCVwZkQsyx2aDpwqaXAa2D41pRXGz1mYmVXW226on5J90B9KNnbw+5RWUUS0A+eTfcg/BtwQEfMkTZJ0Zsr2XWAQ8CtJsyRNTWVXAt8kCzgzgEkprTAN9XVsdjeUmVlZveqGAoZGRD44XC3pS9UKRcQ0YFpJ2kW57Yp3VEXEZGByL+v3svWrkwe4zcwq6G3LYoWkj0iqTz8fAV4osmK7mruhzMwq622w+ATZbbPPAc8CZwMfL6pStZB1Q7llYWZWTm+7ob4JTOia4iM9Yf09siCyR+jnxY/MzCrqbcvi2PxcUGmw+YRiqlQbDX4oz8ysot4Gi7r83EypZdHbVsluoV+9fDeUmVkFvf3A/z5wn6QbyZ6k/gBwSWG1qoGGOrcszMwq6e0T3D+T1AycRDYVx3sjYn6hNdvFfDeUmVllve5KSsFhjwoQef18N5SZWUW9HbPY4zX4bigzs4ocLJJs8aMgwgHDzKyUg0XSry6bFd3TlJuZbcvBImmozy6Fu6LMzLblYJH0q89aFh7kNjPbloNF0i+1LDa3O1iYmZVysEi6g4W7oczMtuFgkTQ2dAULtyzMzEoVGiwkjZW0QFKLpIlljp8o6WFJ7ZLOLjnWkVbP615Br0hdYxZtDhZmZtsobDJASfXA5cApQCswQ9LUkmlCngY+BnylzCk2RMTxRdWvVGPqhmrzmIWZ2TaKnDl2DNASEYsAJE0BxpGbMiQiFqdjNf+E3jJmUfOqmJn1OUV2Qw0DluT2W1Nabw2Q1CzpAUlnlcsg6dyUp3n58uUvp64eszAz60GRwUJl0rbnVqPDI6IJ+BDwI0lHbXOyiKsioikimoYOHbqj9QS2tCza2n03lJlZqSKDRStwWG5/OLC0t4UjYmn6dxFwFwWvzNfY4AFuM7NKigwWM4BRkkZKagTGA726q0nSYEn90/YQ4K0UPD26H8ozM6ussGAREe3A+cB04DHghoiYJ2mSpDMBJL1RUivwfuBKSfNS8dcAzZJmA3cC3y56sSWPWZiZVVboOtoRMQ2YVpJ2UW57Bln3VGm5+4DXFVm3Ut1jFg4WZmbb8BPciZ+zMDOrzMEi6WpZtL64ocY1MTPrexwskn0a6wHYf59+Na6JmVnf42CR9G9wN5SZWSUOFknXmMWm9o4a18TMrO9xsEjq6kRDndyyMDMrw8Eip7GhzsHCzKwMB4ucxoY6P2dhZlaGg0VOY71bFmZm5ThY5OzTWM/6Ng9wm5mVcrDI6e8xCzOzshwscjxmYWZWnoNFTv+Gej9nYWZWhoNFTv+GOjZtdsvCzKyUg0VOY0MdmzxmYWa2DQeLnAHuhjIzK6vQYCFprKQFklokTSxz/ERJD0tql3R2ybEJkp5IPxOKrGeXAf3q2LDZwcLMrFRhwUJSPXA5cDowGvigpNEl2Z4GPgZcV1L2QOAbwJuAMcA3JA0uqq5d9mmsZ6PHLMzMtlFky2IM0BIRiyKiDZgCjMtniIjFETEHKP2EPg24PSJWRsSLwO3A2ALrCmR3Q210y8LMbBtFBothwJLcfmtK22llJZ0rqVlS8/Lly3e4ol0G9Kv33VBmZmUUGSxUJi12ZtmIuCoimiKiaejQodtVuXL26VdPW0cnHZ29raaZ2d6hyGDRChyW2x8OLN0FZXfYgH7Z5XBXlJnZ1ooMFjOAUZJGSmoExgNTe1l2OnCqpMFpYPvUlFaoAf2ydbh9R5SZ2dYKCxYR0Q6cT/Yh/xhwQ0TMkzRJ0pkAkt4oqRV4P3ClpHmp7Ergm2QBZwYwKaUVap8ULNyyMDPbWkORJ4+IacC0krSLctszyLqYypWdDEwusn6l+nd3Q3mQ28wsz09w5wxwy8LMrCwHixx3Q5mZledgkdPVsvBqeWZmW3OwyNl/n2wIZ83G9hrXxMysb3GwyBnYmAWLdZscLMzM8hwscgb1T8GizcHCzCzPwSJnYH+3LMzMynGwyGlsqKNfvVjnAW4zs604WJQY2L/BLQszsxIOFiUGNjaw1sHCzGwrDhYlBvavZ/0md0OZmeU5WJTYt7HBd0OZmZUodCLB3dHSlzawbM2mWlfDzKxPccuihAOFmdm2HCxKfHDM4QxsrK91NczM+hQHixL3L1zBurYOr8NtZpZTaLCQNFbSAkktkiaWOd5f0i/T8QcljUjpIyRtkDQr/fy4yHrmLX5hPQBrPZmgmVm3woKFpHrgcuB0YDTwQUmjS7J9EngxIv4W+CHwndyxhRFxfPo5r6h6lrr4H7Iqbmz37bNmZl2KbFmMAVoiYlFEtAFTgHElecYB16TtG4F3SVKBdarqwEH9AVi1YXMtq2Fm1qcUGSyGAUty+60prWyeiGgHVgEHpWMjJT0i6c+S3lbuBSSdK6lZUvPy5ct3SqUH79sPgJfWO1iYmXUpMliUayGUjhpXyvMscHhEnABcAFwnaf9tMkZcFRFNEdE0dOjQl11hgMH7NgLw4vq2nXI+M7M9QZHBohU4LLc/HFhaKY+kBuAAYGVEbIqIFwAiYiawEDi6wLp2239A1rKYPu+5XfFyZma7hSKDxQxglKSRkhqB8cDUkjxTgQlp+2zgTxERkoamAXIkHQmMAhYVWNduBw7KWha3znWwMDPrUth0HxHRLul8YDpQD0yOiHmSJgHNETEV+AlwraQWYCVZQAE4EZgkqR3oAM6LiJVF1TWva7W8cccfuitezsxst1Do3FARMQ2YVpJ2UW57I/D+MuVuAm4qsm49GfXKQR7gNjPL8USCZTyxbC1PLFtb62qYmfUZnu7DzMyqcrAo422jhgCwuaOzxjUxM+sbHCzKWLR8HQDzl66ucU3MzPoGB4syvviuUQDMbn2pxjUxM+sbHCzKGDQgG/f/3788WeOamJn1DQ4WZZw6+mAAnl65vsY1MTPrGxwsymio92Uxs9p6btVGnnphXa2r0c3PWVQREdR41nQz2wu9+Vt3ALD422fUuCYZf4Wu4OTXvBKAx55dU+OalNfRGXR66Vcz20UcLCp456uzYPHuy/5S45qUd9RXpzHhpw/VuhpmtpdwsKhg/BsPr3UVqvrLEytqXQUz20s4WFRQX7dlnOKG5iU95DQz2/M5WPRgxEH7AvCvN85h5lO7ZIb0PqetvZP2MtOe3DLnWUZMvIX1be01qJWZ7WoOFj2461/e2b39vivu57Qf3g1kc0at2rCZjl00wNzRGTz0ZG2C1dFfv5XTfpS97wcXvcCzqzYA8P3bFwCw9KVsf9majVz34NM1qaPVzu9mPcPC5Z6heW/gYFHFr857S/f2gufXMGLiLYz62q0c9++3cdRXpzFi4i2MmHgLV9/7JPe2rGDq7KVMnb2UuxYsY8TEW5i3dNXLrsOP/7yQD1x5P/ct3L4xinN/1sxHJ7/8QfCFaa6sc656gJO+9+eSo1l33Wd//jBf/c1clqQHGb/zh78yYuItRFQOqNc+8FTZ5WsjghVrN/Wqbm3tnRxz0R/43axnepW/SBs3d7CpvYONmzu2+SIREfzP3Yt4oeR9nXX5vbzq67futDq0tRc/+eVpP7ybb948H4AvTpnFu75f+juxZ2hr7+Tz1z/S62AYEVW/QF7467mMmHjLdtXjD4/2jVU7Cw0WksZKWiCpRdLEMsf7S/plOv6gpBG5Yxem9AWSTiuynj1544gDufnzf18138W/n8+H//dBvnD9I3zh+kf42E9nAHDGZfd0B5QRE2/hLd+6g9/NeoYbmpcwYfJDXHbHE6zd1M6MxSu5de6z/Pvv59HW3klEdHf/dD2Y8/QL2z5R/p0//JW5rav463OrWbZmIy3Lttzqe9v857n78eVs3NzBFXct5NFnssC1bPXGbQJPRLBuU9aldOKld3Lx1Hll3+eGzR3bpC1bvZHn12wEsg9MgCvuWgjApvZONrR1dLdI8v7tt4/y6WtnAvDMSxtYlP4or7x7EU3/74/MXrLt3Fw/uedJFi1fy6b2DhavWMeqDZtZ19bBpN/PL1vftvZOlq3ZyDdvns/TL6znjseeB7Zc029Ne4yHn34RgMUr1vH1387d6g/+xXVtrN64mV/OeJrVGzfz4KIXyr4OwLEX38bffetPvPrf/sBnf5G9r7sWLOPFdW3Mf3Y1l0x7jC9OmQXA2k3tbO7oZNaSl7qv0Ya2Ldd24+at95ev2cSnrpnBqvWbWbxiHV+a8gg/vffJ7vf40vo2ps5eytFfv7Xih9sNzUt43xX3AdCybA0r17UBlA3oT7+wnikPlW8pLnh+DT+5p/pUOG3tnRVv737qhXXdgbOjM7j63ifZ1N7BSd+7a5sP00efWcXV9z5JRDBj8cqy9d3c0dmd/tXfzOXD//tA97GFy9d2t8xP+t5djE0t5cvvbOFT1zRvdZ77WlbQktaymd36Er+fvZR3ff/PvLQ+u1b/Me0xLv3DX2lZtoaPTn6o+/cdYOSF0zjqq9N4btVGbpzZWrae16druqGto9dfis77efa7tKm9g7mtW758fnTyQ93/n7uCevrm97JOnK2h/ThwCtBKtib3ByNifi7PZ4FjI+I8SeOBf4yIcySNBq4HxgCHAn8Ejo6IbT+pkqampmhubq50eKe4ec5Szr/ukUJfw2pvvwEN7Ne/gaWrNlbM87ZRQ6rejfalk0fxoz8+AcDVH39j9xeI7XHw/v15fvUmxow8kIeeXMnAxnrWtZX/M2isr6Oto5Mjhw7snjkZ4HXDDmDuMz23cAc21vPPJx7JfgP6saGtne/d9vhWxwfv249xxw/jglOP5tiLbwPgLUcexP0peH7mHUfxzle9kpXr2riheQnfeu/reNN/3MEbjhjMSa9+JY8+s4orPvKGst+qf3jOcXz5l7O3Svv4W0fwlVNfxeIX1nHGZfcAMGRQIyvWtvGaQ/bnlNEHc9kdT/DNccdw2mv/hjGXZA+w7T+ggdUbt4yjzbroFI6fdDsAZ7zuEG6Z+ywA+/VvYE36cnTkkIH884lHcm/LCm6ekx2/b+JJjL/qga2m/Dn7DcO5cWbrNvW/4sOv5w0jBnfXoZwzjj2EW9K5Sw3q38CogwfRGfDrz/wd196/mDcfdRBjf7T1bfunHXMw0+c9v035yR9r4oiDBnLU0EEVX78nkmZGRFPVfAUGi7cAF0fEaWn/QoCI+FYuz/SU535JDcBzwFBgYj5vPl+l19sVwaIn7R2dLF+7iYa6OjZu7uDdl/2F/g31ve5OMTN7OXb0Se/eBosip/sYBuTvOW0F3lQpT0S0S1oFHJTSHygpO6z0BSSdC5wLcPjhtX0uoqG+jkMO2Kd7f+7FNes5IyLoDGjv7KSxvo4I6EzTlry0vo01G9s57MB92bi5g7b2Tta1tbNyXRub2jv567OvzzdXAAAHb0lEQVSrGdi/gVf9zX7ctWA5y9ds4l/HvorHnl3NI0+/xJnHHcql0xdw48xW/nXsqzh4vwHcPv95Zre+xAWnHM2/3Din+9vvJf/4Wg4/cF/+6SeVx03OaTqMe1pW8MxLW3dTjRlxII8vW+O10G2v9om3jmTyvX1j9usiWxbvB06LiE+l/X8CxkTE53N55qU8rWl/IVnX0yTg/oj4eUr/CTAtIm6q9Hq1blmYme2OetuyKHKAuxU4LLc/HFhaKU/qhjoAWNnLsmZmtosUGSxmAKMkjZTUCIwHppbkmQpMSNtnA3+KrKkzFRif7pYaCYwCPBGSmVmNFDZmkcYgzgemA/XA5IiYJ2kS0BwRU4GfANdKaiFrUYxPZedJugGYD7QDn+vpTigzMytWYWMWu5rHLMzMtl9fGLMwM7M9hIOFmZlV5WBhZmZVOViYmVlVe8wAt6TlwFMv4xRDAC89ty1fl8p8bcrzdamsL16bIyJiaLVMe0yweLkkNffmjoC9ja9LZb425fm6VLY7Xxt3Q5mZWVUOFmZmVpWDxRZX1boCfZSvS2W+NuX5ulS2214bj1mYmVlVblmYmVlVDhZmZlbVXh8sJI2VtEBSi6SJta7PriBpsaS5kmZJak5pB0q6XdIT6d/BKV2SLkvXZ46k1+fOMyHlf0LShEqv15dJmixpmaRHc2k77VpIekO61i2prHbtO9xxFa7NxZKeSb87syS9O3fswvQ+F0g6LZde9m8sLV/wYLpmv0xLGfR5kg6TdKekxyTNk/TFlL5n/95ExF77QzZ1+kLgSKARmA2MrnW9dsH7XgwMKUm7FJiYticC30nb7wZuBQS8GXgwpR8ILEr/Dk7bg2v93nbgWpwIvB54tIhrQbYOy1tSmVuB02v9nl/mtbkY+EqZvKPT309/YGT6u6rv6W8MuAEYn7Z/DHym1u+5l9flEOD1aXs/4PH0/vfo35u9vWUxBmiJiEUR0QZMAcbVuE61Mg64Jm1fA5yVS/9ZZB4AXiHpEOA04PaIWBkRLwK3A2N3daVfroi4m2wtlbydci3Ssf0j4v7IPgF+ljtXn1fh2lQyDpgSEZsi4kmghezvq+zfWPqmfBJwYyqfv859WkQ8GxEPp+01wGPAMPbw35u9PVgMA5bk9ltT2p4ugNskzZR0bko7OCKeheyPAXhlSq90jfbka7ezrsWwtF2avrs7P3WnTO7qamH7r81BwEsR0V6SvluRNAI4AXiQPfz3Zm8PFuX6AfeGe4nfGhGvB04HPifpxB7yVrpGe+O1295rsSdeoyuAo4DjgWeB76f0ve7aSBoE3AR8KSJW95S1TNpud2329mDRChyW2x8OLK1RXXaZiFia/l0G/Iasq+D51Pwl/bssZa90jfbka7ezrkVr2i5N321FxPMR0RERncD/kP3uwPZfmxVk3TENJem7BUn9yALFLyLi1yl5j/692duDxQxgVLoro5FsDfCpNa5ToSQNlLRf1zZwKvAo2fvuuhtjAvC7tD0V+Gi6o+PNwKrUxJ4OnCppcOqKODWl7Ql2yrVIx9ZIenPqo/9o7ly7pa4Pw+QfyX53ILs24yX1lzQSGEU2SFv2byz1xd8JnJ3K569zn5b+L38CPBYRP8gd2rN/b2o9wl7rH7I7FR4nu2Pja7Wuzy54v0eS3ZEyG5jX9Z7J+pDvAJ5I/x6Y0gVcnq7PXKApd65PkA1ktgAfr/V728HrcT1Zd8pmsm90n9yZ1wJoIvtAXQj8F2nWhN3hp8K1uTa99zlkH4KH5PJ/Lb3PBeTu3qn0N5Z+Fx9K1+xXQP9av+deXpe/J+sWmgPMSj/v3tN/bzzdh5mZVbW3d0OZmVkvOFiYmVlVDhZmZlaVg4WZmVXlYGFmZlU5WJiVIem+9O8ISR/ayef+arnXMuvLfOusWQ8kvYNsltX3bEeZ+ojo6OH42ogYtDPqZ7aruGVhVoaktWnz28Db0toNX5ZUL+m7kmakyfQ+nfK/I61xcB3Zg1dI+m2arHFe14SNkr4N7JPO94v8a6UnfL8r6dG0lsE5uXPfJelGSX+V9Is+sb6B7VUaqmcx26tNJNeySB/6qyLijZL6A/dKui3lHQO8NrIpugE+ERErJe0DzJB0U0RMlHR+RBxf5rXeSzZB33HAkFTm7nTsBOAYsjmC7gXeCtyz89+uWXluWZhtn1PJ5vmZRTYt9UFk8yABPJQLFABfkDQbeIBswrhR9Ozvgesjm6jveeDPwBtz526NbAK/WcCInfJuzHrJLQuz7SPg8xGx1aSJaWxjXcn+ycBbImK9pLuAAb04dyWbctsd+G/XdjG3LMx6toZs6cwu04HPpCmqkXR0mr231AHAiylQvJpsOc0um7vKl7gbOCeNiwwlW9b0oZ3yLsxeJn87MevZHKA9dSddDfwnWRfQw2mQeTnll7z8A3CepDlks7A+kDt2FTBH0sMR8eFc+m/I1l2eTTar6b9GxHMp2JjVlG+dNTOzqtwNZWZmVTlYmJlZVQ4WZmZWlYOFmZlV5WBhZmZVOViYmVlVDhZmZlbV/wdB/0mXVx7B7wAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ - "Model restored.\n" + "Parameters trained\n", + "Model saved in path: mu_MagDown_test_4/CNN_model.ckpt\n" ] } ], @@ -492,84 +562,50 @@ }, { "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(3000, 2)" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "output.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 32, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "a=np.argmax(output, axis=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "NN_selected=X_test[a.astype(np.bool)]" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "a=[MC_sig_dict[\"Ds_ConsD_M_norm\"][i][0]*Ds_mass for i in range(m)]" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "b=[NN_selected[i][4]*Ds_mass for i in range(NN_selected.shape[0])]" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAA64AAAD8CAYAAAB3qPkTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHSJJREFUeJzt3X/wZXV93/HnS0CSaX4A8tXSXciSZJOKnRHNtyupTUPVAEKbJa1kcDrKKJ1NWuxozLRC0on5RWfTRElto+0aqJhJgjRq2SrWbFHqOCM/FkUEkbAiDRt2YJNF1ElDCr77x/185bp7v/u9d7/3e++59z4fM3fuOZ/7Off7Pp8533PO+5zP+dxUFZIkSZIkddVzph2AJEmSJElHY+IqSZIkSeo0E1dJkiRJUqeZuEqSJEmSOs3EVZIkSZLUaSaukiRJkqROM3GVJEmSJHWaiaskSZIkqdNMXCVJkiRJnXb8tAM4mlNPPbW2bNky7TAkSXPirrvu+vOqWpp2HLPMY7MkaZyGPTZ3OnHdsmULe/funXYYkqQ5keT/TDuGWeexWZI0TsMem+0qLEmSJEnqNBNXSZIkSVKnmbhKkiRJkjrNxFWSJEmS1GkmrpIkSZKkTjNxlSRJkiR1momrJEmSJKnTTFwlSZIkSZ1m4ipJkiRJ6rTjpx2ApNFsufKjR5Q9vPOiKUQiSZI0OZ4DLTbvuEqSJEmSOs3EVZIkSZLUaSaukiRJkqROM3GVJEmSJHWagzNJHTZoEAJJkiRp0XjHVZIkSZLUaSaukiRJkqROM3GVJGnOJDkuyeeSfKTNn5nk9iQPJvlAkue28hPb/L72+ZZpxi1J0mpMXCVJmj9vBu7vm/8N4Jqq2go8AVzeyi8HnqiqHwSuafUkSeocE1dJkuZIks3ARcDvtvkArwD+qFW5Hri4TW9v87TPX9nqS5LUKSaukiTNl98G/g3wzTb/POCrVfV0m98PbGrTm4BHANrnT7b63ybJjiR7k+w9ePDgRsYuSdJA/hyOJElzIsk/Ah6vqruSnLtSPKBqDfHZswVVu4BdAMvLy0d8Lknj5M8BapCh77iOY6CHJFe18geSnD/ulZEkacG9HPjJJA8DN9DrIvzbwElJVi5WbwYebdP7gdMB2uffCxyaZMCSJA1jlK7C6xroIclZwKXAi4ALgHcnOW594UuSpBVVdVVVba6qLfSOuZ+oqn8GfBJ4Tat2GXBTm97d5mmff6KqvKMqSeqcoRLXMQ30sB24oaqeqqqvAPuAbeNYCUmSdFRvA96aZB+9Z1ivbeXXAs9r5W8FrpxSfJIkHdWwz7iuDPTw3W1+6IEekqwM9LAJuK3vO/uXkSRJY1RVtwK3tumHGHCxuKr+CrhkooFJknQM1rzj2j/QQ3/xgKprDfQw1AAQjlwoSZIkSeo3TFfhcQ308K3yAct8S1XtqqrlqlpeWloaeYUkSZIkSfNlzcR1jAM97AYubaMOnwlsBe4Y25pIkiRJkubSen7H9W3ADUl+Hfgc3z7Qw++1gR4O0Ut2qar7ktwIfBF4Griiqp5Zx9+XJEmSJC2AkRLX9Q70UFVXA1ePGqQkSZIkaXGN8juukiRJkiRNnImrJEmSJKnTTFwlSZIkSZ1m4ipJkiRJ6jQTV0mSJElSp5m4SpIkSZI6zcRVkiRJktRpJq6SJEmSpE4zcZUkSZIkdZqJqyRJkiSp00xcJUmaE0m+I8kdST6f5L4kv9LK35fkK0nubq+zW3mSvCvJviT3JHnpdNdAkqTBjp92AJIkaWyeAl5RVd9IcgLw6SQfa5/966r6o8PqvxrY2l4vA97T3iVJ6hTvuEqSNCeq5xtt9oT2qqMssh14f1vuNuCkJKdtdJySJI3KxFWSpDmS5LgkdwOPA3uq6vb20dWtO/A1SU5sZZuAR/oW39/KJEnqFBNXSZLmSFU9U1VnA5uBbUn+DnAV8LeBvwucArytVc+grzi8IMmOJHuT7D148OAGRS5J0upMXCVJmkNV9VXgVuCCqjrQugM/BfxXYFurth84vW+xzcCjA75rV1UtV9Xy0tLSBkcuSdKRTFwlSZoTSZaSnNSmvxN4FfClledWkwS4GLi3LbIbeH0bXfgc4MmqOjCF0CVJOipHFZYkaX6cBlyf5Dh6F6dvrKqPJPlEkiV6XYPvBn621b8ZuBDYB/wl8IYpxCxJ0ppMXCVJmhNVdQ/wkgHlr1ilfgFXbHRckiStl12FJUmSJEmdZuIqSZIkSeo0E1dJkiRJUqeZuEqSJEmSOs3EVZIkSZLUaSaukiRJkqROM3GVJEmSJHWaiaskSZIkqdNMXCVJkiRJnWbiKkmSJEnqNBNXSZIkSVKnmbhKkiRJkjrNxFWSJEmS1GkmrpIkSZKkTjt+2gFIWr8tV350YPnDOy+acCSSJEnS+HnHVZIkSZLUaSaukiTNiSTfkeSOJJ9Pcl+SX2nlZya5PcmDST6Q5Lmt/MQ2v699vmWa8UuStBoTV0mS5sdTwCuq6sXA2cAFSc4BfgO4pqq2Ak8Al7f6lwNPVNUPAte0epIkdY6JqyRJc6J6vtFmT2ivAl4B/FErvx64uE1vb/O0z1+ZJBMKV5KkoTk4k9QBqw2uJEmjSnIccBfwg8DvAF8GvlpVT7cq+4FNbXoT8AhAVT2d5EngecCfH/adO4AdAGecccZGr4IkSUdY847rOJ+XSXJVK38gyfkbtVKSJC2qqnqmqs4GNgPbgBcOqtbeB91drSMKqnZV1XJVLS8tLY0vWEmShjRMV+GxPC+T5CzgUuBFwAXAu9tVYUmSNGZV9VXgVuAc4KQkK72sNgOPtun9wOkA7fPvBQ5NNlJJkta2ZuI6xudltgM3VNVTVfUVYB+9K8GSJGkMkiwlOalNfyfwKuB+4JPAa1q1y4Cb2vTuNk/7/BNVdcQdV0mSpm2owZmSHJfkbuBxYA8jPC8DrDwv863yAcv0/60dSfYm2Xvw4MHR10iSpMV1GvDJJPcAdwJ7quojwNuAtybZR++YfG2rfy3wvFb+VuDKKcQsSdKahhqcqaqeAc5uV3E/zLE9LzP0czTALoDl5WWv+kqSNKSqugd4yYDyhxjQy6mq/gq4ZAKhSZK0LiP9HM46n5f5VvmAZSRJkiRJGmiYUYXH9bzMbuDSNurwmcBW4I5xrYgkSZIkaT4N01X4NOD6NgLwc4Abq+ojSb4I3JDk14HP8e3Py/xee17mEL2RhKmq+5LcCHwReBq4onVBliRJkiRpVWsmruN8XqaqrgauHj1MSZIkSdKiGukZV0mSJEmSJs3EVZIkSZLUaSaukiRJkqROM3GVJEmSJHWaiaskSZIkqdNMXCVJkiRJnWbiKkmSJEnqNBNXSZIkSVKnmbhKkiRJkjrNxFWSJEmS1GkmrpIkSZKkTjNxlSRpTiQ5Pcknk9yf5L4kb27lv5zkz5Lc3V4X9i1zVZJ9SR5Icv70opckaXXHTzsASZI0Nk8DP19Vn03y3cBdSfa0z66pqt/qr5zkLOBS4EXA3wL+V5IfqqpnJhq1JElr8I6rJElzoqoOVNVn2/TXgfuBTUdZZDtwQ1U9VVVfAfYB2zY+UkmSRmPiKknSHEqyBXgJcHsrelOSe5Jcl+TkVrYJeKRvsf0cPdGVJGkqTFwlSZozSb4L+CDwlqr6GvAe4AeAs4EDwDtWqg5YvAZ8344ke5PsPXjw4AZFLUnS6kxcJUmaI0lOoJe0/n5VfQigqh6rqmeq6pvAe3m2O/B+4PS+xTcDjx7+nVW1q6qWq2p5aWlpY1dAkqQBTFwlSZoTSQJcC9xfVe/sKz+tr9pPAfe26d3ApUlOTHImsBW4Y1LxSpI0LEcVliRpfrwceB3whSR3t7JfAF6b5Gx63YAfBn4GoKruS3Ij8EV6IxJf4YjCkqQuMnGVJGlOVNWnGfzc6s1HWeZq4OoNC0qSpDGwq7AkSZIkqdNMXCVJkiRJnWbiKkmSJEnqNBNXSZIkSVKnmbhKkiRJkjrNxFWSJEmS1GkmrpIkSZKkTjNxlSRJkiR1momrJEmSJKnTTFwlSZIkSZ1m4ipJkiRJ6jQTV0mSJElSp5m4SpIkSZI6zcRVkiRJktRpJq6SJEmSpE4zcZUkaU4kOT3JJ5Pcn+S+JG9u5ack2ZPkwfZ+citPkncl2ZfkniQvne4aSJI0mImrJEnz42ng56vqhcA5wBVJzgKuBG6pqq3ALW0e4NXA1vbaAbxn8iFLkrQ2E1dJkuZEVR2oqs+26a8D9wObgO3A9a3a9cDFbXo78P7quQ04KclpEw5bkqQ1rZm4jrPbUZLLWv0Hk1y2caslSdJiS7IFeAlwO/CCqjoAveQWeH6rtgl4pG+x/a1MkqROGeaO61i6HSU5BXg78DJgG/D2lWRXkiSNT5LvAj4IvKWqvna0qgPKasD37UiyN8negwcPjitMSZKGtmbiOsZuR+cDe6rqUFU9AewBLhjr2kiStOCSnEAvaf39qvpQK35spQtwe3+8le8HTu9bfDPw6OHfWVW7qmq5qpaXlpY2LnhJklYx0jOu6+x2ZHckSZI2UJIA1wL3V9U7+z7aDaw8onMZcFNf+evbYz7nAE+uHNslSeqS44eteHi3o96xcXDVAWV1lPLD/84Oel2MOeOMM4YNT5IkwcuB1wFfSHJ3K/sFYCdwY5LLgT8FLmmf3QxcCOwD/hJ4w2TDlSRpOEMlrkfrdlRVB4bsdrQfOPew8lsP/1tVtQvYBbC8vHxEYitJkgarqk8z+EIxwCsH1C/gig0NSpKkMRhmVOFxdTv6OHBekpPboEzntTJJkiRJklY1zB3XsXQ7qqpDSX4NuLPV+9WqOjSWtZAkSZIkza01E9dxdjuqquuA60YJUJIkSZK02IYenEmSJEmSumTLlR89ouzhnRdNIRJtNBNXSZIkSVMxKPGUBhnpd1wlSZIkSZo0E1dJkiRJUqeZuEqSJEmSOs3EVZIkSZLUaSaukiRJkqROc1RhacIcPU+SJEkajYmrNMf8bTNJkiTNA7sKS5IkSZI6zcRVkiRJktRpJq6SJEmSpE4zcZUkSZIkdZqJqyRJcyLJdUkeT3JvX9kvJ/mzJHe314V9n12VZF+SB5KcP52oJUlam4mrJEnz433ABQPKr6mqs9vrZoAkZwGXAi9qy7w7yXETi1SSpBGYuEqSNCeq6lPAoSGrbwduqKqnquorwD5g24YFJ0nSOpi4SpI0/96U5J7WlfjkVrYJeKSvzv5WJklS55i4SpI0394D/ABwNnAAeEcrz4C6NegLkuxIsjfJ3oMHD25MlJIkHYWJqyRJc6yqHquqZ6rqm8B7ebY78H7g9L6qm4FHV/mOXVW1XFXLS0tLGxuwJEkDmLhKkjTHkpzWN/tTwMqIw7uBS5OcmORMYCtwx6TjkyRpGMdPOwBJkjQeSf4QOBc4Ncl+4O3AuUnOptcN+GHgZwCq6r4kNwJfBJ4GrqiqZ6YRtyRJazFxlSRpTlTVawcUX3uU+lcDV29cRJIkjYddhSVJkiRJnWbiKkmSJEnqNBNXSZIkSVKnmbhKkiRJkjrNxFWSJEmS1GkmrpIkSZKkTjNxlSRJkiR1momrJEmSJKnTTFwlSZIkSZ1m4ipJkiRJ6jQTV0mSJElSp5m4SpIkSZI6zcRVkiRJktRpJq6SJEmSpE4zcZUkSZIkdZqJqyRJcyTJdUkeT3JvX9kpSfYkebC9n9zKk+RdSfYluSfJS6cXuSRJqzNxlSRpvrwPuOCwsiuBW6pqK3BLmwd4NbC1vXYA75lQjJIkjWTNxHVcV26TXNbqP5jkso1ZHUmSFltVfQo4dFjxduD6Nn09cHFf+fur5zbgpCSnTSZSSZKGN8wd1/exziu3SU4B3g68DNgGvH0l2ZUkSRvuBVV1AKC9P7+VbwIe6au3v5VJktQpayauY7pyez6wp6oOVdUTwB6OTIYlSdJkZUBZHVEp2ZFkb5K9Bw8enEBYkiR9u2N9xnXUK7dDX9H14ChJ0tg9ttIFuL0/3sr3A6f31dsMPHr4wlW1q6qWq2p5aWlpw4OVJOlw4x6cabUrt0Nd0QUPjpIkbYDdwMr4EpcBN/WVv76NUXEO8OTKhWlJkrrkWBPXUa/cDnVFV5IkrU+SPwQ+A/xwkv1JLgd2Aj+R5EHgJ9o8wM3AQ8A+4L3Av5xCyJIkren4Y1xu5crtTo68cvumJDfQG4jpyao6kOTjwL/rG5DpPOCqYw9bmg1brvzotEOQtGCq6rWrfPTKAXULuGJjI5Ikaf3WTFzbldtzgVOT7Kc3OvBO4MZ2FfdPgUta9ZuBC+lduf1L4A0AVXUoya8Bd7Z6v1pVhw/4JEmSJEnSEdZMXMd15baqrgOuGyk6SZIkSdLCG/fgTJIkSZIkjZWJqyRJkiSp0451cCZJM2rQgFEP77xoCpFIkiRJw/GOqyRJkiSp00xcJUmSJEmdZuIqSZIkSeo0E1dJkiRJUqc5OJMkSZKkueFAlPPJO66SJEmSpE4zcZUkSZIkdZqJqyRJkiSp00xcJUmSJEmdZuIqSZIkSeo0E1dJkiRJUqf5cziSJC2AJA8DXweeAZ6uquUkpwAfALYADwM/XVVPTCtGSfNt0M/USMPyjqskSYvjH1bV2VW13OavBG6pqq3ALW1ekqTOMXGVJGlxbQeub9PXAxdPMRZJklZlV2FpTOz+IqnjCvjjJAX8l6raBbygqg4AVNWBJM8ftGCSHcAOgDPOOGNS8UqS9C0mrpIkLYaXV9WjLTndk+RLwy7YktxdAMvLy7VRAUqStBq7CkuStACq6tH2/jjwYWAb8FiS0wDa++PTi1CSpNV5x1XSwG7OD++8aAqRSNoISf4G8Jyq+nqbPg/4VWA3cBmws73fNL0oJUlanYmrJEnz7wXAh5NA79j/B1X1P5PcCdyY5HLgT4FLphijJEmrMnGVJGnOVdVDwIsHlP8F8MrJRyRJ0mh8xlWSJEmS1GkmrpIkSZKkTjNxlSRJkiR1ms+4SiMaNAKvJEmSpI3jHVdJkiRJUqeZuEqSJEmSOs3EVZIkSZLUaT7jKmmg1Z7lfXjnRROORJIkzZIujgfiec3sM3GVjqKLO15JkiRp0dhVWJIkSZLUad5xlTSSQXeh7WYjSZKkjWTiKkmSJGkheUF+dpi4aiH57KokSfPDgXek+WfiKmndvFopSdoI6z2+eHzaeN4M0KSYuGquuTOVJGm8Rjm2DkoS13t31GP79Nj2mqaJJ65JLgD+A3Ac8LtVtXPSMWj2uePsPq9yS7PDY7Okwy3yuZbnMN000cQ1yXHA7wA/AewH7kyyu6q+OMk4NFnD/vMv8g5yUaz3Kr2k8fPYrKNZ77G5i8f2LiYlGxGT51/j1cXtZtFM+o7rNmBfVT0EkOQGYDvgwXFM1rvzmVQ3HXeSWsuw28hq26wHGGloHpvHZJIDBE0q0Zmkaf/9Ya3n+LRR6zjs985KG8+69V408HxlsFTV5P5Y8hrggqr6523+dcDLqupNg+ovLy/X3r17x/K3/UeVNI/We2K0aAfHJHdV1fK04+iSrh2bu5jkbcSdq/WexG7ExT1Js2Xad9DHtb8e9tg86cT1EuD8ww6O26rqX/XV2QHsaLM/DDwwhj99KvDnY/ieRWF7Dc+2Go3tNTzbajTDttf3VdXSRgczS6Z4bJ6EWfw/msWYYTbjnsWYwbgnaRZjhtmLe6hj86S7Cu8HTu+b3ww82l+hqnYBu8b5R5Ps9Qr78Gyv4dlWo7G9hmdbjcb2WpepHJsnYRa3i1mMGWYz7lmMGYx7kmYxZpjduNfynAn/vTuBrUnOTPJc4FJg94RjkCRJz/LYLEnqvIneca2qp5O8Cfg4vSH3r6uq+yYZgyRJepbHZknSLJj477hW1c3AzRP+szPXvWnKbK/h2Vajsb2GZ1uNxvZahykdmydhFreLWYwZZjPuWYwZjHuSZjFmmN24j2qigzNJkiRJkjSqST/jKkmSJEnSSGY2cU1yXZLHk9zbV3Z2ktuS3J1kb5JtrfzcJE+28ruT/FLfMhckeSDJviRXTmNdNtoqbfXiJJ9J8oUk/yPJ9/R9dlVrjweSnN9XPvdtBaO1V5ItSf5v37b1n/uW+ZFWf1+SdyXJNNZnIyU5Pcknk9yf5L4kb27lpyTZk+TB9n5yK09ri31J7kny0r7vuqzVfzDJZdNap410DO21sPuuo7TVJW3+m0mWD1tmofddi2CV/fMH+v5HHk5yd99nndgmRom7K8eVVWJe7TyrM/v2EePuxD52lZg7f542Stwd2q5n8rzlGOLuxLY9dlU1ky/gHwAvBe7tK/tj4NVt+kLg1jZ9LvCRAd9xHPBl4PuB5wKfB86a9rpNqK3uBH68Tb8R+LU2fVZrhxOBM1v7HLcobXUM7bWlv95h33MH8KNAgI+tbJvz9AJOA17apr8b+JO2Df174MpWfiXwG236wtYWAc4Bbm/lpwAPtfeT2/TJ016/DrTXwu67jtJWL6T3O6K3Ast99Rd+37UIr0H758M+fwfwS13bJkaMuxPHlUExs/p5Vmf27SPG3Yl97Coxd/48bcS4u7Jdz+R5yzHE3Ylte9yvmb3jWlWfAg4dXgysXJH6Xg77HboBtgH7quqhqvpr4AZg+1gD7YBV2uqHgU+16T3AP23T24EbquqpqvoKsI9eOy1EW8HI7TVQktOA76mqz1RvT/F+4OJxxzptVXWgqj7bpr8O3A9sordtXN+qXc+z674deH/13Aac1NrqfGBPVR2qqifotfEFE1yViTiG9lrN3P8/rtZWVXV/VT0wYJGF33ctglX2z0Dvzgjw08AftqLObBMjxj3QpI8rI55ndWbfPovnh7N6njaL50uzet7i+UPPzCauq3gL8JtJHgF+C7iq77MfTfL5JB9L8qJWtgl4pK/O/la2CO4FfrJNX8KzPz6/WpssclvB6u0FcGaSzyX530l+rJVtotdGK+a+vZJsAV4C3A68oKoOQG9nCzy/VXP7aoZsL3DfdXhbrcZtSz8GPFZVD7b5WdkmDo8buntcWe08q+ttPYvnh7N6njYz50uzet6yyOcP85a4/gvg56rqdODngGtb+WeB76uqFwP/EfjvrXxQH/pFGWb5jcAVSe6i1+Xgr1v5am2yyG0Fq7fXAeCMqnoJ8FbgD9rzHAvVXkm+C/gg8Jaq+trRqg4oW7jta4T2Wvh9l9uWRvBavv2u5axsE4fH3eXjymrnWV1v61k8P5zV87SZOF+a1WPLop8/zFviehnwoTb93+jdDqeqvlZV32jTNwMnJDmV3lWG/itBm1m7+8hcqKovVdV5VfUj9A6YX24frdYmC9tWsHp7ta46f9Gm72rlP0SvvTb3fcXctleSE+jtRH+/qlb+/x5rXWlWugE93soXfvsapb0Wfd+1SlutZuG3rUWW5HjgnwAf6Cvu/DYxKO6OH1cGnmfR/baeufPDWT1Pm4XzpVk9b/H8Yf4S10eBH2/TrwAeBEjyN1dGKEtvJLnnAH9B7wHyrUnOTPJc4FJg98SjnoIkz2/vzwH+LbAyuttu4NIkJyY5E9hK76H5hW0rWL29kiwlOa5Nfz+99nqoddf4epJz2rb3euCmqQS/gdq6XQvcX1Xv7PtoN70TBdr7TX3lr0/POcCTra0+DpyX5OQ2It55rWyujNpei7zvOkpbrcZ912J7FfClqurvcjgL28QRcXf8uDLwPIvu79tn7vxwVs/Tun6+NKvnLZ4/NNWBEaKO5UXvKs4B4P/Ru3pwOfD3gbvojZB1O/Ajre6bgPta+W3A3+v7ngvpjcz1ZeAXp71eE2yrN7f1/hNgJ5C++r/Y2uMB+kZ2W4S2GrW96A06sLJtfRb4x33fs0zvWY8vA/+pv43n5dX+5wq4B7i7vS4EngfcQu/k4BbglFY/wO+0NvkC3z4q7BvpDTKxD3jDtNetI+21sPuuo7TVT7X/y6eAx4CP9y2z0PuuRXgN2j+38vcBPzugfie2iVHi7spxZVDMrH6e1Zl9+4hxd2Ifu0rMnT9PGyXuDm3XM3necgxxd2LbHvdrZWOSJEmSJKmT5q2rsCRJkiRpzpi4SpIkSZI6zcRVkiRJktRpJq6SJEmSpE4zcZUkSZIkdZqJqyRJkiSp00xcJUmSJEmdZuIqSZIkSeq0/w+U3YZ1uS8RSQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.subplot(1,2,1)\n", - "plt.hist(a,bins=70);\n", - "plt.subplot(1,2,2)\n", - "plt.hist(b,bins=70);\n", + "if task=='TEST':\n", "\n", - "fig=plt.gcf();\n", - "fig.set_size_inches(16,4)" + " Ds_mass_MC =[MC_sig_dict[\"Ds_ConsD_M\"][i][0] for i in range(m)]\n", + " NN_selected = X_test_0[np.argmax(output, axis=1).astype(np.bool)]\n", + " Ds_mass_sel_NN = [NN_selected[i][dim] for i in range(NN_selected.shape[0])]\n", + " Ds_mass_train_NN =[X_train_0[i][dim] for i in range(X_train_0.shape[0])]\n", + "\n", + " plt.subplot(1,2,1)\n", + " plt.hist(Ds_mass_MC,bins=70);\n", + " plt.subplot(1,2,2)\n", + " plt.hist(Ds_mass_sel_NN,alpha=0.8,bins=70);\n", + " #plt.hist(Ds_mass_train_NN,alpha=0.2,bins=70);\n", + "\n", + " fig=plt.gcf();\n", + " fig.set_size_inches(20,8)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "if task=='TRAIN':\n", + " hyper_dict={\n", + " 'm':m,\n", + " 'test_size':test_size,\n", + " 'val_size':val_size,\n", + " 'LEARNING_RATE':LEARNING_RATE,\n", + " 'BETA1':BETA1,\n", + " 'BATCH_SIZE':BATCH_SIZE,\n", + " 'EPOCHS':EPOCHS,\n", + " 'VAL_PERIOD':VAL_PERIOD,\n", + " 'SEED':SEED,\n", + " 'sizes':sizes,\n", + " 'LAMBD':LAMBD,\n", + " 'PATH':PATH,\n", + " }\n", + " with open(PATH+'/hyper_parameters.pkl', 'wb') as f: \n", + " pickle.dump(hyper_dict, f)" ] }, {