Newer
Older
HCAL_project / HCAL_bicycleGAN.ipynb
@Davide Lancierini Davide Lancierini on 2 Dec 2018 59 KB First commit
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy as sp\n",
    "import numpy as np\n",
    "import os \n",
    "import pickle\n",
    "\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "from datetime import datetime\n",
    "\n",
    "from architectures.bicycle_GAN import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "task='TRAIN'\n",
    "#task='TEST'\n",
    "\n",
    "# Option to save and restore hyperparameters\n",
    "\n",
    "PATH='HCAL_bycicleGAN_test35'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "if task =='TRAIN' and os.path.exists(PATH+'/hyper_parameters.pkl'):\n",
    "    with open(PATH+'/hyper_parameters.pkl', 'rb') as f:  \n",
    "        hyper_dict = pickle.load(f)\n",
    "        for key, item in hyper_dict.items():\n",
    "            print(key+':'+str(item))\n",
    "     \n",
    "    reco_path = hyper_dict['reco_path']\n",
    "    true_path = hyper_dict['true_path']\n",
    "    n_batches = hyper_dict['n_batches']\n",
    "    test_size = hyper_dict['test_size']\n",
    "    LEARNING_RATE = hyper_dict['LEARNING_RATE']\n",
    "    BETA1 = hyper_dict['BETA1']\n",
    "    BATCH_SIZE = hyper_dict['BATCH_SIZE']\n",
    "    EPOCHS = hyper_dict['EPOCHS']\n",
    "    SAVE_SAMPLE_PERIOD = hyper_dict['SAVE_SAMPLE_PERIOD']\n",
    "    SEED = hyper_dict['SEED']\n",
    "    d_sizes = hyper_dict['d_sizes']\n",
    "    g_sizes_enc = hyper_dict['g_sizes_enc']\n",
    "    g_sizes_dec = hyper_dict['g_sizes_dec']\n",
    "    e_sizes = hyper_dict['e_sizes']\n",
    "    preprocess = hyper_dict['preprocess']\n",
    "    cost_type = hyper_dict['cost_type']\n",
    "    validating_size=hyper_dict['validating_size']\n",
    "    cycl_weight=hyper_dict['cycl_weight']\n",
    "    latent_weight=hyper_dict['latent_weight']\n",
    "    kl_weight=hyper_dict['kl_weight']\n",
    "    discr_steps=hyper_dict['discr_steps']\n",
    "    gen_steps=hyper_dict['gen_steps']\n",
    "    vae_steps=hyper_dict['vae_steps']\n",
    "    \n",
    "elif task=='TRAIN' and not os.path.exists(PATH+'/hyper_parameters.pkl'):\n",
    "    \n",
    "    reco_path = '/disk/lhcb_data/davide/HCAL_project/piplus_cells_inout/piplus/reco/'\n",
    "    true_path = '/disk/lhcb_data/davide/HCAL_project/piplus_cells_inout/piplus/true/'\n",
    "    #reco_path = '/disk/lhcb_data/davide/HCAL_project_full_event/reco/'\n",
    "    #true_path = '/disk/lhcb_data/davide/HCAL_project_full_event/true/'\n",
    "    n_batches = 1\n",
    "    test_size = 5000\n",
    "    validating_size=1000\n",
    "    \n",
    "    LEARNING_RATE = 2e-4\n",
    "    BETA1 = 0.5\n",
    "    BATCH_SIZE = 16\n",
    "    EPOCHS = 4\n",
    "    SAVE_SAMPLE_PERIOD = 200\n",
    "    SEED = 1\n",
    "    preprocess=False\n",
    "    cost_type='GAN'\n",
    "    \n",
    "    latent_weight=100\n",
    "    cycl_weight=10\n",
    "    kl_weight=1\n",
    "    discr_steps=1\n",
    "    gen_steps=4\n",
    "    vae_steps=4\n",
    "    \n",
    "    latent_dims=16\n",
    "    ndf = 16\n",
    "    ngf = 16\n",
    "    nef = 16\n",
    "    s = 2\n",
    "    f = 4\n",
    "    d = 0.8\n",
    "    \n",
    "    stddev_d=0.02\n",
    "    stddev_g=0.02\n",
    "    stddev_e=0.02\n",
    "\n",
    "\n",
    "    g_sizes_enc={\n",
    "        'latent_dims':latent_dims,\n",
    "    \n",
    "        'conv_layers': [\n",
    "                            (ngf/2, f,   s,  False, 1, lrelu, tf.truncated_normal_initializer(stddev_g)), #(batch, 52, 64, 1) =>  (batch, 26, 32, ngf)\n",
    "                            (ngf,   f,   s, 'bn',  d,  lrelu, tf.truncated_normal_initializer(stddev_g)),#(batch, 26, 32, ngf) => (batch, 13, 16, ngf*2)\n",
    "                            (ngf*2, f,   s, 'bn',  1,  lrelu, tf.truncated_normal_initializer(stddev_g)),#(batch, 13, 16, ngf*4) => (batch, 7, 8, ngf*4)\n",
    "                            (ngf*4, f,   s, 'bn',  1,  lrelu, tf.truncated_normal_initializer(stddev_g)),#(batch, 7, 8, ngf*4) => (batch, 4, 4, ngf*4)\n",
    "                            (ngf*8, f,   s, 'bn',  d,  lrelu, tf.truncated_normal_initializer(stddev_g)),#(batch, 4, 4, ngf*4) => (batch, 2, 2, ngf*4)\n",
    "                            #(ngf*8, f, s, 'bn', 1, lrelu, tf.truncated_normal_initializer(stddev_g)),#(batch, 2, 2, ngf*4) => (batch, 1, 1, ngf*4)\n",
    "                        \n",
    "                       ],\n",
    "    }\n",
    "    \n",
    "    g_sizes_dec={\n",
    "    \n",
    "         \n",
    "         'deconv_layers': [\n",
    "    \n",
    "                            (ngf*4, f,   s,  'bn',   1, tf.nn.softplus, tf.truncated_normal_initializer(stddev_g)),#(batch, 1, 1, ngf*4) => (batch, 2, 2, ngf*4*2)\n",
    "                            (ngf*2, f,   s,  'bn',   d, tf.nn.softplus, tf.truncated_normal_initializer(stddev_g)),#(batch, 2, 2, ngf*4*2) => (batch, 4, 4, ngf*4*2)\n",
    "                            (ngf,   f,   s,  'bn',   1, tf.nn.softplus, tf.truncated_normal_initializer(stddev_g)),#(batch, 4, 4, ngf*4*2) => (batch, 7, 8, ngf*4*2)\n",
    "                            (ngf/2, f,   s,  'bn',   1, tf.nn.softplus, tf.truncated_normal_initializer(stddev_g)),#(batch, 7, 8, ngf*4*2) => (batch, 13, 16, ngf*2*2)\n",
    "                            (1,     f,   s,   False, d, tf.nn.softplus, tf.truncated_normal_initializer(stddev_g)),#(batch, 26, 32, ngf*2) => (batch, 52, 64, 1)\n",
    "                       \n",
    "                         ],  \n",
    "    }\n",
    "    \n",
    "    \n",
    "    d_sizes={\n",
    "        \n",
    "         'conv_layers': [\n",
    "                             (ndf/2,  f, s,  False, 1, lrelu, tf.truncated_normal_initializer(stddev_d)), #(batch, 52, 64, 2) => (batch, 26, 32, ndf)\n",
    "                             (ndf,    f, s, 'bn',   d, lrelu, tf.truncated_normal_initializer(stddev_d)), #(batch, 26, 32, ndf) => (batch, 13, 16, ndf*2)\n",
    "                             (ndf*2,  f, s, 'bn',   1, lrelu, tf.truncated_normal_initializer(stddev_d)), #(batch, 13, 16, ndf*2) => (batch, 7, 8, ndf*4)\n",
    "                             (ndf*4,  f, s, 'bn',   d, lrelu, tf.truncated_normal_initializer(stddev_d)), #(batch, 7, 8, ndf*4) => (batch, 7, 8, ndf*8)\n",
    "                             (ndf*8,  f, 1, 'bn',   1, lrelu, tf.truncated_normal_initializer(stddev_d)), #(batch, 7, 8, ndf*8) => (batch, 7, 8, ndf*8)\n",
    "                             \n",
    "                        ],\n",
    "                        \n",
    "         \n",
    "         'dense_layers': [\n",
    "                             (ndf*32, 'bn',  d, lrelu, tf.truncated_normal_initializer(stddev_d)),\n",
    "                             #(ndf*4,  'bn',  d, lrelu, tf.truncated_normal_initializer(stddev_d)),             \n",
    "                             (ndf,   False,  d, lrelu, tf.truncated_normal_initializer(stddev_d))],\n",
    "        \n",
    "         'readout_layer_w_init':tf.truncated_normal_initializer(stddev_d),\n",
    "    }\n",
    "    \n",
    "\n",
    "    \n",
    "    e_sizes={\n",
    "        'latent_dims':latent_dims,\n",
    "        \n",
    "        #'conv_layer_0':[\n",
    "        #                    (nef,   f, s, False, 1,lrelu, tf.truncated_normal_initializer(stddev_e)),\n",
    "        #                ],\n",
    "        'conv_layers':[\n",
    "            \n",
    "                            (nef/2, f,   s,   False, 1, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                            (nef,   f,   s,   'bn',  1, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                            (nef*2, f,   s,   'bn',  d, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                            (nef*4, f,   s,   'bn',  1, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                            (nef*8, f,   s,   'bn',  d, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                        ],\n",
    "\n",
    "        \n",
    "        #'convblock_layer_0':[\n",
    "        #                     (nef*2, 1, s, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 26, 32, ndf) => (batch, 13, 16, ndf*2)\n",
    "        #                     (nef*2, f, 1, 'bn', d, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 13, 16, ndf*2) => (batch, 13, 16, ndf*2)\n",
    "        #                     (nef*2, 1, 1, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 13, 16, ndf*2) => (batch, 13, 16, ndf*2)\n",
    "        #                     \n",
    "        #                ],\n",
    "        #'convblock_shortcut_layer_0':[\n",
    "        #                    (nef*2, 1, s, False, 1, tf.random_normal_initializer(stddev_e))\n",
    "        #                ],\n",
    "        #\n",
    "        #'convblock_layer_1':[\n",
    "        #                     (nef*4, 1, s, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 26, 32, ndf) => (batch, 13, 16, ndf*2)\n",
    "        #                     (nef*4, f, 1, 'bn', d, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 13, 16, ndf*2) => (batch, 13, 16, ndf*2)\n",
    "        #                     (nef*4, 1, 1, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 13, 16, ndf*2) => (batch, 13, 16, ndf*2)\n",
    "        #                     \n",
    "        #                ],\n",
    "        #'convblock_shortcut_layer_1':[\n",
    "        #                    (nef*4, 1, s, False, 1, tf.random_normal_initializer(stddev_e))\n",
    "        #                ],\n",
    "        #'convblock_layer_2':[\n",
    "        #                     (nef*4, 1, s, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 26, 32, ndf) => (batch, 13, 16, ndf*2)\n",
    "        #                     (nef*4, f, 1, 'bn', d, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 13, 16, ndf*2) => (batch, 7, 8, ndf*4)\n",
    "        #                     (nef*4, 1, 1, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 7, 8, ndf*4) => (batch, 7, 8, ndf*8)\n",
    "        #                     \n",
    "        #                ],\n",
    "        #\n",
    "        #'convblock_shortcut_layer_2':[\n",
    "        #                    (nef*4, 1, s, False, 1, tf.random_normal_initializer(stddev_e))\n",
    "        #                ],\n",
    "        \n",
    "       # 'convblock_layer_3':[\n",
    "       #                       (nef*8, 1, s, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 26, 32, ndf) => (batch, 13, 16, ndf*2)\n",
    "       #                       (nef*8, f, 1, 'bn', 1, lrelu, tf.random_normal_initializer(stddev_e)), #(batch, 13, 16, ndf*2) => (batch, 7, 8, ndf*4)\n",
    "       #                       (nef*8, 1, 1, 'bn', 1, tf.nn.relu,tf.random_normal_initializer(stddev_e)), #(batch, 7, 8, ndf*4) => (batch, 7, 8, ndf*8)\n",
    "       #                       \n",
    "       #                  ],\n",
    "       # 'convblock_shortcut_layer_3':[\n",
    "       #                      (nef*8, 1, s, False, 1, tf.random_normal_initializer(stddev_e))\n",
    "       #                     ],\n",
    "\n",
    "        \n",
    "        'dense_layers':[    \n",
    "                            (nef*16, 'bn',   d, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                            (nef*8, 'bn',   d, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                            (nef,   False, d, lrelu, tf.random_normal_initializer(stddev_e)),\n",
    "                      \n",
    "                        ],\n",
    "        'readout_layer_w_init':tf.random_normal_initializer(stddev_e)\n",
    "        \n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if task == 'TEST' and os.path.exists(PATH+'/hyper_parameters.pkl'):\n",
    "    with open(PATH+'/hyper_parameters.pkl', 'rb') as f:  \n",
    "        hyper_dict = pickle.load(f)\n",
    "        for key, item in hyper_dict.items():\n",
    "            print(key+':'+str(item))\n",
    "     \n",
    "    reco_path = hyper_dict['reco_path']\n",
    "    true_path = hyper_dict['true_path']\n",
    "    #true_path_K = hyper_dict['true_path_K']\n",
    "    n_batches = hyper_dict['n_batches']\n",
    "    test_size = hyper_dict['test_size']\n",
    "    LEARNING_RATE = hyper_dict['LEARNING_RATE']\n",
    "    BETA1 = hyper_dict['BETA1']\n",
    "    BATCH_SIZE = hyper_dict['BATCH_SIZE']\n",
    "    EPOCHS = hyper_dict['EPOCHS']\n",
    "    SAVE_SAMPLE_PERIOD = hyper_dict['SAVE_SAMPLE_PERIOD']\n",
    "    SEED = hyper_dict['SEED']\n",
    "    d_sizes = hyper_dict['d_sizes']\n",
    "    g_sizes_enc = hyper_dict['g_sizes_enc']\n",
    "    g_sizes_dec = hyper_dict['g_sizes_dec']\n",
    "    e_sizes = hyper_dict['e_sizes']\n",
    "    preprocess = hyper_dict['preprocess']\n",
    "    cost_type = hyper_dict['cost_type']\n",
    "    validating_size=hyper_dict['validating_size']\n",
    "    cycl_weight=hyper_dict['cycl_weight']\n",
    "    latent_weight=hyper_dict['latent_weight']\n",
    "    kl_weight=hyper_dict['kl_weight']\n",
    "    discr_steps=hyper_dict['discr_steps']\n",
    "    gen_steps=hyper_dict['gen_steps']\n",
    "    vae_steps=hyper_dict['vae_steps']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim=1\n",
    "select=False\n",
    "if preprocess=='normalise':\n",
    "    train_true, test_true, min_true, max_true, train_reco, test_reco, min_reco, max_reco = load_data(true_path, reco_path, n_batches, select=select, n_cells=dim*dim, energy_fraction=1, preprocess=preprocess, test_size=test_size)\n",
    "else:\n",
    "    train_true, test_true, train_reco, test_reco = load_data(true_path, reco_path, n_batches,  select=select, n_cells=None, energy_fraction=1, preprocess=preprocess, test_size=test_size)\n",
    "    \n",
    "train_true, train_reco = delete_undetected_events_double(train_true, train_reco)\n",
    "test_true, test_reco = delete_undetected_events_double(test_true, test_reco)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfsAAACtCAYAAAC3K9aMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHJJJREFUeJzt3XncHVWd5/HPl4QQZV9kDRBkUUFbULRBXBgBAQWlp8FGESJjQ7fSioMLi+IyozYoA4wzLTYtTliNCAiIKLKLjaaNI6jsO4kJm2SBsIb8+o9zLqnn5q7P3ev5vl+v+8qtqlNVv1t57jn3nFN1jiICMzMzK69VBh2AmZmZ9ZYLezMzs5JzYW9mZlZyLuzNzMxKzoW9mZlZybmwNzMzKzkX9mZWl6TbJO3eg+NOlxSSJo9z/xMkfa/bcZmVlQt7szokPShpz6p1H5X0q6p1H5Y0R9LTkhZI+pmkt9fYLyR9sGr97pLmtRjPDfkYb6xaf2lev3th3XaSfiTpCUmLJf1B0jGSJrX48QGIiB0i4oZ29um2WtcoIr4REX8/qJjMRo0Le7MOSDoGOB34BrARsAXwHeADVUlnAE/mfztxN3BY4fzrA7sAjxfWbQ3MBuYCb4iItYGDgJ2BNTs8v5mNIBf2ZuMkaW3gfwBHRcQlEbE0Il6MiJ9ExOcK6bYE3gUcCewtaaMOTns+8HeFGvqHgB8DLxTSfBW4OSKOiYgFABFxV0R8OCIW1fgcG0i6QtIiSU9KuknSKnnby60bkr6SWwvOk/SUpD/mFoTjJT0maa6k9xSOO6ZlJO9/Xq0PJelwSXfk494v6R/y+tWBnwGb5paTpyVtWn0sSe/PXQ6LcgvI66ri+Gxu3Vgs6YeSprZ32c1Gmwt7s/HbFZhKKmwbOQyYExEXA3cAh3RwzvnA7UClUD0MOKcqzZ7ARW0c8zPAPOBVpNaJE4B642jvD5wLrAv8HriKlI9sRvrh869tnLfoMWA/YC3gcOA0SW+KiKXAvsD8iFgjv+YXd5S0HfAD4NP5M1wJ/ETSlEKyDwL7AFsBfwV8dJxxmo0kF/ZmjV2aa4uLJC0iNdFXrA88ERHLmhzjMOCC/P4COm/KPwc4TNJrgHUi4tdV29cHFrRxvBeBTYAtc8vETVF/0oybIuKq/Jl/RCpcT4qIF4FZwHRJ67T1aYCI+GlE3BfJjcAvgHe0uPvfAT+NiKtzHKcArwDeVkjz7YiYHxFPAj8Bdmw3RrNR5sLerLEDImKdygv4RGHbX4ANGt1RLmk3Um1yVl51AfAGSZ0UNpcA7wY+SaplV/sLqfBu1beAe4Ff5Cb04xqkfbTw/lnSj52XCssAa7RxbgAk7SvpN7kbYRHwXmCDFnffFHioshARy0n3K2xWSPNI4f0z44nRbJS5sDcbv18DzwEHNEgzAxBwi6RHSDfOQeEmu3ZFxDOkfuyPU7uwvwb42zaO91REfCYiXk1qpj9G0h7jja9gKfDKwvLGtRJJWg24mFQj3yj/qLqSdN2gfpdCxXxgy8LxBGwO/Hl8YZuVjwt7s3GKiMXAl4B/kXSApFdKWjXXUr+ZbwL7IOnGvB0Lr08ChxRbBCRNrXqpximLTgDeFREP1tj2ZeBtkr4laeN8/G3yjXUrNbFL2i9vF7AEeCm/OnULcHC+JjsDB9ZJNwVYjfREwTJJ+7LingRIrQnr5xsia7kQeJ+kPSStSroH4Xng5i58BrNScGFv1oGIOBU4BvgiqbCaC/wTcCmpxv8scE5EPFJ5AWcBk0g3jEFqbn626rV1k/POj4hf1dl2H+nmwenAbZIWk2rOc4CnauyyLak14GlSa8V3uvRs/Ymkz7GQ9ITABbUSRcRTwKdIhfZC4MPA5YXtd5JuwLs/3zuxadX+dwEfAf4P8ASpdWL/iCg+oWA2oan+fThmZmZWBq7Zm5mZlZwLezMzs5JzYW9mZlZyLuzNzMxK
buCFfWG866clLZf0bGG5k2FFWz3/eZJeqIrjdw3S715It1RptrHivpvW27dwjAslPSJpiaQ7JRUnNlld0sWSHsrH3qVq31UknSZpodKMZl8rbNuzKpZKfO9rEs+G+XjXVK3/iKS7tGIc9LrHkTQrn+s9Veu/m9cf3CQGSXpA0odrbDtWVTPNmVlzeV6ASp76iKSZkvo6oFA+Z3Uee2uD9O9oksdu0eR8k6rSPy3pJUmnFdKsnvOmvyjNl3B9Ydtncl60RNKfJf0v5cdkJW2c87oFeb+bJL2lQSxfy/F/omr9Z/P6L7Zw/a6V9KUa6/82x9daOR4RQ/MCHgT2bJJmcpfPeR7wlXHuu026hG3vtwMwpfD+ceD1efmVpMeQdiPNkrZL1b5HA38ijZC2BXAP8NE659mH9CjTak3iORf4JXBNYd1WpGeV9yANbvJfSYOkrFPnGLOAu4DzC+tWI41c9gBwcAvX5avAz2us/xNwxKD/Pv3ya9RexTyVNKjRrcDX+xzDTOBr49x3OmlQpXHn+6T5Fp4B3lZYNyvn/RuQHoN9c2Hb1sDa+f0GwI3Ap/LyNqQ5GDbO+32cNK/DK+uc+2s5X5xdtf7WvP6LLcR/KHBPjfWXAie3eh0GXrNvJv8y+qGkH0h6CvhIro1/pZBmT0kPFpanSfqxpMfzL7SjBhB6XRFxW6x4BjhIhemr87ZnIuLbEfHvwPIau88AvhkRCyLiYdL0qh+tc6oZwKyIeL5eLEpzoG/Bys9AbwE8GhHXRnJJjmerBh/tEmBPSZVpVPcnPbf9l6pz/kNuMXhS0k8lVYY1PQfYQ9ImhbRvIn35LmxwXjNrItIYD1dRmBdA0mqSTpH0sKRHc233FYXtH5B0S67l3idpn7x+U0mX5+/wvZKO6P8natlBwJ8j4mYASTuQJlf6x4h4IiJeioiXW3MjzdGwuLD/clIhT0TcGxGnRxoz46WIOANYnTRWRT2/BtZTmssCpaGyVyFNJPUypZkbb1UaS+JXkl6fN10CbCzpbYW065OGlK6eBKuuoS/ss78hFUZrAz9slFBp6s8rgN+SBivZC/icujP8Z0skfVlSw1nHJJ0l6VngNuA+4OoWD7896Vdhxa2k1oHq469NmlP97AYxrAp8mzQITLWbgbmS9s7NYh8kFdp3NIhtKSkzqYyUttKMbLk5/9OkHwIbkf7gz4OXB4P5DWNnhTsUuKzqy2dmbZI0jVTI3VtYfTKwHekHwDakPPNLOf1bSd/fzwHrAO8ktRRAGuRoHmleggOBb/Q5j/2CpEtbTD6DsfngXwP3A19X6gr9g6QxQ15LOjRXLh8n5a9n1oljZ1Jl7f4mMZzLiiGya+WLbwH+Dfh70kRW3wcukzQl0syPFzF2iO2DgT9ExG1NzrtCP5tzWmiueJCqZnxSM8h1VevGNL2TpvR8ML/fDbi/Kv2JwL/VOed5pPHNFxVeZ7UY77ia8Qv7TyLNc34CMKnG9icoNOMDq5JaAqYX1r0BeK7GvkcAdzY5//HAafn9P1Joxi+sewZYRhp5ba8Gx5pFGkVuT+B6UvPXAtJQqHPIzfh52yFVn+lF0pjokP7Y/5DfTyZ1A+w76L9Nv/waxVfOU5/O398AriV3xZEKqaXA1oX0uwIP5Pf/Wskfqo65OWk45TUL6/4ZmFknhpk18tizW4x/Oh0045NaTF8Ctiis+1I+5ok5f3p3vg7b1dj/NcD/BDassW1tUmXtcw3O/7X8+bfK/xersuJH0ixyMz6poP9y1b73Abvl97uTunVXy8uzgU+2cy1GpWY/t420WwJbaOy0pJ+nziQc2UlRmNksIj7WUbQtitQMdCPpl3XTc0aavvN5Uh9UxVrUHgK1+tfsGJK2JBWsX66zfX9SH/rbSF+IvYFzJW3fJMzrSF+w44BLYuUhS7cEvlv4v3mc9GNiWt5+IbBNbup6D+lL+Ysm5zSz+g6IiDVJBcZrWTGb4KtI9wj9rvB9/HleD6lQv6/G8TYF
now0zHHFQ4ydZbDaKVV5bKfTPLfqMOCGSF2eFc+Sfnx8IyJeiIjrSPcs7VW9c6ShmO8C/m9xvaTVgZ8Cv4yIbzULIiIeAB4GvgHcFhHzq5JsCRxbVW5twopreiOwGNhf0nbATqTWlZbVnZpzyFSP6dtoNq25pJsZXtfzqLpnMk3GQi+4HXgj8Ie8/EbSr8uXSdoa2IXU1FPPrqQv7d1Kc668ApgqaV5ETMvHvS4ibsnpb5Z0C+lX8O31DhoRyyX9gPQDa9caSeaSfglfXGf/Jbl57jDSH/v5sWIKVTMbp4i4UdJM0uyCB5BaDp8FdoiIWjMEzqV2vjSf1Ae9ZqHA34Ihm2VQKWM7lFRpKarkna2OFT8mf1aa4OoyUtP9J+rtVMM5pO6AQ2tsmwt8NSJOrrVjRISkSlfArcCVEfFEG+cemZp9tVtIs1ytm2/m+lRh26+BF/LjE1Nzf/MbJL15MKGOlW9sOVDp0Y/JkvYjTUd6XSHNavkPCmBK4T3kPjSlR0A2J/V/z6w6zQxSQT2vQSiXkpqWKjOxfZ3UNFS5Tr8F/kvlJpHcf7cLK74ojZxC6o6ZXWPbd4EvFm5WWVdS9XSsZ5P67Rvec2BmbTsd2EvSjhGxnNR8fJqkDQEkbSZp75z2LOBwpdkEV8nbXhsRc0n39PxzzmP/itQyef4APk8j7wA2JE0CVXQ9qXvw2JwHvxN4O7kFUdIRheuxA3AsqfsDSVNIN8wtBg6P3KbeogtIrZW1KjpnAkdJeouSNSTtn1sQKs4mPWH13xhHvjiqhf1M0o1iD5GanWZVNkTEMtJdim8l9ZE8Qep7Wqv6IAUnaOwzmY90Epykr0r6cYMkR5N+Hf+FVMh+PCKuKmx/iPSLe31S882zylOVkm6ou5b0+W8BfhQRMwvnrvyaXemPQdLHlMcQiIjnYuxMbEuAFyLi0bz9KuCbwOX5RpUfACdGxC+bff5Id7heV2fbD0hNYpdIWpI/Q3Xz2dWkfvy7IuKPzc5nZq2JiMdJFYYT86pjSTfs/SZ/H68h9VMTEf8BHA6cRircbiQ1NwN8iNSfPh/4Mam/udFNxp+vymPbqpVWk3SipJ80STYDuCjSDW4vy12L7ydVJhYDZ5DuI7onJ3kn8CdJS0k3e1/Oiuv1DtJNjvsCiwufp1Yr5hiRnrS6JiKeq7FtNukxvjNIj0vfTZrJsZjmPuA/gKmkLoS2eNY7MzOzkhvVmr2ZmZm1aCCFvaR9lAZVuVfScYOIoRlJM7TykIsNh3k0M+ulUcg7WyXpkDp5bOvPjlvL+t6MrzTozd2kftp5pBvBPhQRde/wNjOb6Jx3WicGUbN/K3BvRNyfb5SYRbpRwszM6nPeaeM2iOfsN2PsIDnzSMMX1jVFq8VUVm+UxEbccyzlhXheg47DbIj1Nu+s/vY1a/RtlL6bx+pUv47d7UbyRscubHsuWss7B1HY1wpqpcsk6UjgSICpvJK/7t+wyzYAs+PaQYdgNux6mndq8tjiIJYtG3f6bh6rU/06djeP2+zYxW2/WXYVrRhEM/480jCMFdNIz2qOERFnRsTOEbHzqqzWt+DMzIaU804bt0EU9r8FtpW0VR6N6GDSoAVmZlaf804bt74340fEMkn/RJoKdRLw/Whnmj4zswmo13lnp83QnTRpt9PM326XQC91u+m+3rG78ZkGMhFORFwJXDmIc5uZjSrnnTZeHkHPzMys5FzYm5mZldyozGdvZmZDrJ3+62Z90NXH6qRvvJf96v3S8DO0+Hy/a/ZmZmYl58LezMys5NyMb2ZmbRtU03qzEfFGcfS9ZsfuRhyu2ZuZmZWcC3szM7OSc2FvZmZWcu6zNzOzvuqkD7rRDHCtHKudoXf7NRxuN9PW45q9mZlZybmwNzMzKzkX9mZmZiXnPnszM1tJu0PatqPdfvdGsfSyX72Xz9n3m2v2ZmZmJefC3szMrORc2JuZmZWc++zNzGwl7Tyv3iz9pHXXBeClhQtrLvfieffK
OSpaPVen562nk/5/j41vZmZmTblmb2ZmPVFdg6+oXm6WvlHa6mNWr69Y/q6dAFjlxt+P63zNYhjvMfrFNXszM7OSc83ezGyCamec+GqtpG+3tttJjb56uZJu4b6vAWCd25cAsLxwnGbna1bz76Q23844Bg3HHWjxv801ezMzs5JzzX5EXDX/FgD23nTHAUdiZtZf4+kbr6Sr1Oh/duUFALz5Kx8HYKMf3Vn3PJV96z09MIpc2I8IF/Jm1ql2HuHq5HGvZk3trRaetdK1W/Auv+V2APa7e18A1rvrOQCW7rYtq//7PQ3jbLa+qNn16uQxvobHjtaO4WZ8MzOzknPN3szMuqpZTbjdpvjx7Fuxyo7bA7D8bxak5YXp0bupACVonm+Va/ZmZmYl55q9mdkE0as++maa9bN34wa4esdYZfFSAJZ10ErQixv0OpnWdzx6VrOXtLmk6yXdIek2SUfn9etJulrSPfnf2sMdmZlNQM47rRd62Yy/DPhMRLwO2AU4StL2wHHAtRGxLXBtXjYzs6R0eedLCxfy0sKFTFp33brD2RZV0hVfrZ6jWixaQixaMq64K+eud+xiumHXs8I+IhZExP/P758C7gA2Az4AnJ2TnQ0c0KsYzMxGjfNO64W+9NlLmg7sBMwGNoqIBZD+qCVt2I8YzMxGTcd5p1ofErebffQV9QapqdZs8JpGNedm+9SbIKfS0tBKXM3UStfuPRDjHgNhWIbLlbQGcDHw6YhouS1F0pGS5kia8yLP9y5AM7Mh1JW8M5x3WtLTmr2kVUl/rOdHxCV59aOSNsm/TDcBHqu1b0ScCZwJsJbWa3GMIDOz0de1vHOVweSdnd5938rd7+3WymvV/CfC8/UVvbwbX8BZwB0RcWph0+XAjPx+BnBZr2IwMxs1zjutF3pZs98NOBT4o6Rb8roTgJOACyV9DHgYOKiHMZiZjZru5Z3Rm774TnVjytjx1sq70WrQSrpmz8138v8ynrHxe1bYR8SvANXZvEevzmtmNsqcd1oveAQ9MzPrqk77wlu9i388WjlWGfvyPTa+mZlZydWt2Uu6EvhERDzYv3DMzEbbUOWdbTxnP5E1eya/otkd/o0M+to3qtnPBH4h6Qv5MRAzM2tuJs47bcjUrdlHxIWSfgp8CZgj6VxgeWH7qfX2NTObqJx31q8p92JWu17oZtzDotkNei8CS4HVgDUp/MGamVldI5d3dvPRsE4Ly1rpu13gFgfV6cajgMOuUZ/9PsCppIEc3hQRz/QtKjOzEeW804ZRo5r9F4CDIuK2fgVjZlYCzjuzZhPajKfm3I1jTESN+uzf0c9AzMzKwHmnDSMPqmNmVlZtDJfby+Fdu6nbA/b0a9/q61mt2fVvlLYVHlTHzMys5FyzNzOznmh32NtGNedWB7ppppvT5o4S1+zNzMxKzjV7MzNbSS/76OvVoOvVnIvPxPdSL8/RtSltx8k1ezMzs5Jzzd7MzPqq3oQz9Z7L76Sfvd104zlGv1oeOuGavZmZWcm5Zm9mZk21+5w4NK9Fd6M23Oo4/I3u8G81jn7W3ovXsxtjHrhmb2ZmVnKu2ZuZWU80eza+UR94rfWNjKfPfpCj8fWba/ZmZmYl55q9mVlZaWx/b7+f9e60z76Xd9B3Y59G27s5t0DDcfNbPKxr9mZmZiXnmr2ZmfVEpzXzYt9+p8fo9T7DzjV7MzOzknPN3sysrNqYz76b6t2F36pad+tXtxI0G32vGzp5WqCX133MsaO1fVzYm5lZV9UrkFvVymNy/SjcyzTVrZvxzczMSs41ezMzmzAadQ2Mt8Y+nhsIu/loXit6XrOXNEnS7yVdkZe3kjRb0j2SfihpSq9jMDMbNc47rZv60Yx/NHBHYflk4LSI2BZYCHysDzGYmY2akc87X1q4sO3BbIr9/NXL3dSN41biG4U+/J4W9pKmAe8DvpeXBbwbuCgnORs4oJcxmJmNGued1m297rM/Hfg8sGZeXh9YFBGVzol5
wGY9jsHMbNT0Je/s1lC67Zyn0bn6UUPu5jk6GS63349E9qxmL2k/4LGI+F1xdY2kNZ8SlHSkpDmS5rzI8z2J0cxs2DjvtF7oZc1+N+D9kt4LTAXWIv1aXUfS5PwLdRowv9bOEXEmcCbAWlqvxWEDzMxGnvNO67qe1ewj4viImBYR04GDgesi4hDgeuDAnGwGcFmvYjAzGzXOO60XBjGozrHAMZLuJfVDnTWAGMzMRk3X885YtuzlVzOaPHnMq520xfN0+1yD1M0Yqz9ztz9/X65iRNwA3JDf3w+8tR/nNTMbZc47rVs8XK6ZmVnJubA3MzMrueHtDDEzs6FR3dfe6DnyQT6z386Y852OT98ofbNjNeuL7/Y1dM3ezMys5FzYm5mZlZwLezMzs5Jzn72ZmbWtl33hnWinr7zf49MXlWZsfDMzMxsOLuzNzMxKzs34ZmYTVDtN2u00zXfaRD2oY5eZa/ZmZmYl58LezMys5FzYm5mZlZz77M3MJqhu9l/3qv/fusM1ezMzs5JzYW9mZlZyLuzNzMxKThEx6BiakvQ4sBR4YtCx1LABjqsd9eLaMiJe1e9gzMrMeee4jFpcLeWdI1HYA0iaExE7DzqOao6rPcMal1lZDet3znG1p9O43IxvZmZWci7szczMSm6UCvszBx1AHY6rPcMal1lZDet3znG1p6O4RqbP3szMzMZnlGr2ZmZmNg5DX9hL2kfSXZLulXTcAOPYXNL1ku6QdJuko/P69SRdLeme/O+6A4pvkqTfS7oiL28laXaO64eSpgwgpnUkXSTpznzddh2W62VWds47W45vQuSdQ13YS5oE/AuwL7A98CFJ2w8onGXAZyLidcAuwFE5luOAayNiW+DavDwIRwN3FJZPBk7LcS0EPjaAmP438POIeC3wxhzfsFwvs9Jy3tmWiZF3RsTQvoBdgasKy8cDxw86rhzLZcBewF3AJnndJsBdA4hlWv7PfzdwBSDS4AuTa13HPsW0FvAA+b6QwvqBXy+//Cr7y3lny7FMmLxzqGv2wGbA3MLyvLxuoCRNB3YCZgMbRcQCgPzvhgMI6XTg88DyvLw+sCgiKlNJDeK6vRp4HPh/uYnse5JWZziul1nZOe9szYTJO4e9sFeNdQN9fEDSGsDFwKcjYskgY8nx7Ac8FhG/K66ukbTf120y8CbgjIjYiTRkp5vszfpjGPKAMZx3tqwneeewF/bzgM0Ly9OA+QOKBUmrkv5Yz4+IS/LqRyVtkrdvAjzW57B2A94v6UFgFqk56nRgHUmVSaMHcd3mAfMiYnZevoj0Bzzo62U2ETjvbG5C5Z3DXtj/Ftg23x05BTgYuHwQgUgScBZwR0ScWth0OTAjv59B6o/qm4g4PiKmRcR00vW5LiIOAa4HDhxgXI8AcyW9Jq/aA7idAV8vswnCeWcTEy3vHPpBdSS9l/RraxLw/Yj4+oDieDtwE/BHVvTvnEDqe7oQ2AJ4GDgoIp4cUIy7A5+NiP0kvZr0a3U94PfARyLi+T7HsyPwPWAKcD9wOOkH5lBcL7Myc97ZVoy7U/K8c+gLezMzM+vMsDfjm5mZWYdc2JuZmZWcC3szM7OSc2FvZmZWci7szczMSs6F/RDJs0M9IGm9vLxuXt5y0LGZmQ0j55utcWE/RCJiLnAGcFJedRJwZkQ8NLiozMyGl/PN1vg5+yGTh5X8HfB94Ahgp4h4YbBRmZkNL+ebzU1unsT6KSJelPQ54OfAe/wHa2bWmPPN5tyMP5z2BRYArx90IGZmI8L5ZgMu7IdMHhN5L2AX4L9XZjkyM7PanG8258J+iOTZoc4gzff8MPAt4JTBRmVmNrycb7bGhf1wOQJ4OCKuzsvfAV4r6V0DjMnMbJg532yB78Y3MzMrOdfszczMSs6FvZmZWcm5sDczMys5F/ZmZmYl58LezMys5FzYm5mZlZwLezMzs5JzYW9mZlZy/wnwUSqKXTE11AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 792x288 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if preprocess != False:\n",
    "    draw_one_sample(train_true, train_reco, preprocess,\n",
    "                    min_true=min_true, max_true=max_true, \n",
    "                    min_reco=min_reco, max_reco=max_reco,\n",
    "                    save=False, PATH=PATH\n",
    "                   )\n",
    "else:\n",
    "    draw_one_sample(train_true,train_reco)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def HCAL():\n",
    "\n",
    "    \n",
    "    tf.reset_default_graph()\n",
    "    \n",
    "    _, n_H_A, n_W_A ,n_C = train_true.shape\n",
    "    _, n_H_B, n_W_B ,n_C = train_reco.shape\n",
    "    \n",
    "    gan = bicycle_GAN(n_H_A, n_W_A, n_H_B, n_W_B, n_C,\n",
    "                   min_true=min_true, max_true=max_true, \n",
    "                   min_reco=min_reco, max_reco=max_reco,\n",
    "                   d_sizes=d_sizes, g_sizes_enc=g_sizes_enc, g_sizes_dec=g_sizes_dec, e_sizes=e_sizes,\n",
    "                   lr=LEARNING_RATE, beta1=BETA1,\n",
    "                   preprocess=preprocess, cost_type=cost_type,\n",
    "                   cycl_weight=cycl_weight, latent_weight=latent_weight, kl_weight=kl_weight,\n",
    "                   discr_steps=discr_steps, gen_steps=gen_steps, vae_steps=vae_steps,\n",
    "                   batch_size=BATCH_SIZE, epochs=EPOCHS,\n",
    "                   save_sample=SAVE_SAMPLE_PERIOD, path=PATH, seed= SEED)\n",
    "    \n",
    "    vars_to_train= tf.trainable_variables()\n",
    "        \n",
    "    if task == 'TRAIN':\n",
    "        \n",
    "        init_op = tf.global_variables_initializer()\n",
    "        \n",
    "        \n",
    "    if task == 'TEST':\n",
    "        \n",
    "        vars_all = tf.global_variables()\n",
    "        vars_to_init = list(set(vars_all)-set(vars_to_train))\n",
    "        init_op = tf.variables_initializer(vars_to_init)\n",
    "        \n",
    "    saver=tf.train.Saver()\n",
    "    # Add ops to save and restore all the variables.\n",
    "    \n",
    "    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.33)\n",
    "    \n",
    "    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:\n",
    "        \n",
    "        sess.run(init_op)\n",
    "        \n",
    "        if task=='TRAIN':\n",
    "            \n",
    "            if os.path.exists(PATH+'/'+PATH+'bicycle.ckpt.index'):\n",
    "                saver.restore(sess,PATH+'/'+PATH+'bicycle.ckpt')\n",
    "                print('Model restored.')\n",
    "                \n",
    "            gan.set_session(sess)\n",
    "            gan.fit(train_true,train_reco, validating_size)\n",
    "            \n",
    "            save_all = saver.save(sess, PATH+'/'+PATH+'bicycle.ckpt')\n",
    "            print(\"Model saved in path: %s\" % save_all)\n",
    "                  \n",
    "                       \n",
    "        if task=='TEST':\n",
    "            \n",
    "            print('\\n Evaluate model on test set...')\n",
    "            \n",
    "            if os.path.exists(PATH+'/'+PATH+'bicycle.ckpt.index'):\n",
    "                saver.restore(sess, PATH+'/'+PATH+'bicycle.ckpt')\n",
    "                \n",
    "            print('Model restored.')\n",
    "            \n",
    "            gan.set_session(sess)\n",
    "        \n",
    "        #test_reco_NN=gan.get_samples_A_to_B(test_true.reshape(test_true.shape[0],n_H_A,n_W_A,n_C))\n",
    "        test_reco_NN=np.zeros_like(test_true)\n",
    "        t0 = datetime.now()\n",
    "        for i in range(len(test_true)):\n",
    "            test_reco_NN[i]=gan.get_sample_A_to_B(test_true[i].reshape(1,n_H_A,n_W_A,n_C))\n",
    "        per_evt_time=(datetime.now() - t0)/len(test_reco)\n",
    "        print('Per event simulation time {0}'.format(per_evt_time))\n",
    "        done = False\n",
    "\n",
    "        while not done:\n",
    "            \n",
    "            #j = int(input(\"Input event number\"))\n",
    "            if preprocess:\n",
    "                draw_nn_sample(test_true, test_reco, 1, preprocess,\n",
    "                              min_true=min_true, max_true=max_true, \n",
    "                              min_reco=min_reco, max_reco=max_reco,\n",
    "                              f=gan.get_sample_A_to_B, save=False, is_training=False, PATH=PATH)\n",
    "            else:\n",
    "                draw_nn_sample(test_true, test_reco, 1, preprocess,\n",
    "                              f=gan.get_sample_A_to_B, save=False, is_training=False)\n",
    "            \n",
    "            ans = input(\"Generate another?\")\n",
    "            if ans and ans[0] in ('n' or 'N'):\n",
    "                done = True\n",
    "        \n",
    "        done = False\n",
    "        while not done:\n",
    "            \n",
    "            if preprocess:\n",
    "                draw_nn_sample(test_true, test_reco, 20, preprocess,\n",
    "                              min_true=min_true, max_true=max_true, \n",
    "                              min_reco=min_reco, max_reco=max_reco,\n",
    "                              f=gan.get_sample_A_to_B, save=False, is_training=False)\n",
    "            else:\n",
    "                draw_nn_sample(test_true, test_reco, 20, preprocess,\n",
    "                              f=gan.get_sample_A_to_B, save=False, is_training=False)\n",
    "                \n",
    "            ans = input(\"Generate another?\")\n",
    "            if ans and ans[0] in ('n' or 'N'):\n",
    "                done = True\n",
    "        \n",
    "        return test_reco_NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Convolutional Network architecture detected for discriminator B\n",
      "Convolutional Network architecture detected for encoder B\n",
      "Encoder_B\n",
      "Convolution\n",
      "Input for convolution shape  (?, 52, 64, 1)\n",
      "Encoder output shape (?, 16)\n",
      "Generator_A_to_B\n",
      "Input for generator encoded shape (?, 52, 64, 1)\n",
      "Output of generator encoder, \n",
      " and input for generator decoder shape (?, 1, 1, 512)\n",
      "Generator output shape (?, 52, 64, 1)\n",
      "Generator_A_to_B\n",
      "Input for generator encoded shape (?, 52, 64, 1)\n",
      "Output of generator encoder, \n",
      " and input for generator decoder shape (?, 1, 1, 512)\n",
      "Generator output shape (?, 52, 64, 1)\n",
      "Encoder_B\n",
      "Convolution\n",
      "Input for convolution shape  (?, 52, 64, 1)\n",
      "Encoder output shape (?, 16)\n",
      "Discriminator_B\n",
      "Input for convolution shape  (?, 52, 64, 1)\n",
      "minibatch features shape (?, 10)\n",
      "Feature output shape (?, 16)\n",
      "Logits shape (?, 1)\n",
      "Discriminator_B\n",
      "Input for convolution shape  (?, 52, 64, 1)\n",
      "minibatch features shape (?, 10)\n",
      "Feature output shape (?, 16)\n",
      "Logits shape (?, 1)\n",
      "Discriminator_B\n",
      "Input for convolution shape  (?, 52, 64, 1)\n",
      "minibatch features shape (?, 10)\n",
      "Feature output shape (?, 16)\n",
      "Logits shape (?, 1)\n",
      "Generator_A_to_B\n",
      "Input for generator encoded shape (?, 52, 64, 1)\n",
      "Output of generator encoder, \n",
      " and input for generator decoder shape (?, 1, 1, 512)\n",
      "Generator output shape (?, 52, 64, 1)\n",
      "\n",
      " ****** \n",
      "\n",
      "Training bicycleGAN with a total of 16767 samples distributed in 1047 batches of size 16\n",
      "\n",
      "The validation set consists of 1000 images\n",
      "The learning rate is 0.0002, and every 200 batches a generated sample will be saved to HCAL_bycicleGAN_test35\n",
      "\n",
      " ****** \n",
      "\n",
      "Epoch: 0\n",
      "At iter: 200  -  dt: 0:00:00.260050 - d_acc: 0.59, - d_acc_enc: 0.59\n",
      "Discriminator cost 54.03, Generator cost 5920, VAE Cost 3509, KL divergence cost 545.2\n",
      "Saving a sample...\n",
      "At iter: 400  -  dt: 0:00:00.371186 - d_acc: 0.47, - d_acc_enc: 0.50\n",
      "Discriminator cost 73.59, Generator cost 2396, VAE Cost 2247, KL divergence cost 198.6\n",
      "Saving a sample...\n",
      "At iter: 600  -  dt: 0:00:00.378116 - d_acc: 0.66, - d_acc_enc: 0.62\n",
      "Discriminator cost 49.52, Generator cost 686.4, VAE Cost 1286, KL divergence cost 67.42\n",
      "Saving a sample...\n",
      "At iter: 800  -  dt: 0:00:00.373615 - d_acc: 0.53, - d_acc_enc: 0.50\n",
      "Discriminator cost 41.43, Generator cost 315.8, VAE Cost 1015, KL divergence cost 99.63\n",
      "Saving a sample...\n",
      "At iter: 1000  -  dt: 0:00:00.367530 - d_acc: 0.56, - d_acc_enc: 0.59\n",
      "Discriminator cost 41.54, Generator cost 297, VAE Cost 742.2, KL divergence cost 53.33\n",
      "Saving a sample...\n",
      "Printing validation set histograms at epoch 0\n",
      "ET Distribution plots are being printed...\n",
      "Done\n",
      "Resolution plots are being printed...\n",
      "Done\n",
      "Epoch: 1\n",
      "At iter: 1200  -  dt: 0:00:00.298434 - d_acc: 0.72, - d_acc_enc: 0.72\n",
      "Discriminator cost 41.07, Generator cost 247.4, VAE Cost 633.6, KL divergence cost 39.22\n",
      "Saving a sample...\n",
      "At iter: 1400  -  dt: 0:00:00.373918 - d_acc: 0.56, - d_acc_enc: 0.47\n",
      "Discriminator cost 35.16, Generator cost 144.5, VAE Cost 616, KL divergence cost 41.35\n",
      "Saving a sample...\n",
      "At iter: 1600  -  dt: 0:00:00.370199 - d_acc: 0.44, - d_acc_enc: 0.44\n",
      "Discriminator cost 47.92, Generator cost 110.3, VAE Cost 484.1, KL divergence cost 36.05\n",
      "Saving a sample...\n",
      "At iter: 1800  -  dt: 0:00:00.366162 - d_acc: 0.53, - d_acc_enc: 0.47\n",
      "Discriminator cost 44.97, Generator cost 99.2, VAE Cost 376.2, KL divergence cost 30.15\n",
      "Saving a sample...\n",
      "At iter: 2000  -  dt: 0:00:00.372570 - d_acc: 0.53, - d_acc_enc: 0.41\n",
      "Discriminator cost 38.9, Generator cost 108.8, VAE Cost 383.8, KL divergence cost 27.91\n",
      "Saving a sample...\n",
      "Printing validation set histograms at epoch 1\n",
      "ET Distribution plots are being printed...\n",
      "Done\n",
      "Resolution plots are being printed...\n",
      "Done\n",
      "Epoch: 2\n",
      "At iter: 2200  -  dt: 0:00:00.287552 - d_acc: 0.66, - d_acc_enc: 0.53\n",
      "Discriminator cost 18.25, Generator cost 87.42, VAE Cost 334.6, KL divergence cost 28.26\n",
      "Saving a sample...\n",
      "At iter: 2400  -  dt: 0:00:00.366586 - d_acc: 0.44, - d_acc_enc: 0.44\n",
      "Discriminator cost 34.34, Generator cost 82.36, VAE Cost 353.6, KL divergence cost 29.78\n",
      "Saving a sample...\n",
      "At iter: 2600  -  dt: 0:00:00.362463 - d_acc: 0.50, - d_acc_enc: 0.50\n",
      "Discriminator cost 38.3, Generator cost 79.64, VAE Cost 326, KL divergence cost 27.11\n",
      "Saving a sample...\n",
      "At iter: 2800  -  dt: 0:00:00.363595 - d_acc: 0.47, - d_acc_enc: 0.47\n",
      "Discriminator cost 33.89, Generator cost 79.25, VAE Cost 246, KL divergence cost 26.67\n",
      "Saving a sample...\n",
      "At iter: 3000  -  dt: 0:00:00.373816 - d_acc: 0.56, - d_acc_enc: 0.53\n",
      "Discriminator cost 16.85, Generator cost 80.94, VAE Cost 293.6, KL divergence cost 27.16\n",
      "Saving a sample...\n",
      "Printing validation set histograms at epoch 2\n",
      "ET Distribution plots are being printed...\n",
      "Done\n",
      "Resolution plots are being printed...\n",
      "Done\n",
      "Epoch: 3\n",
      "At iter: 3200  -  dt: 0:00:00.300006 - d_acc: 0.53, - d_acc_enc: 0.53\n",
      "Discriminator cost 23.1, Generator cost 82.69, VAE Cost 251.7, KL divergence cost 26.43\n",
      "Saving a sample...\n",
      "At iter: 3400  -  dt: 0:00:00.373295 - d_acc: 0.56, - d_acc_enc: 0.50\n",
      "Discriminator cost 16.79, Generator cost 72.6, VAE Cost 221.8, KL divergence cost 26.54\n",
      "Saving a sample...\n",
      "At iter: 3600  -  dt: 0:00:00.372298 - d_acc: 0.38, - d_acc_enc: 0.38\n",
      "Discriminator cost 38.19, Generator cost 71.64, VAE Cost 212.6, KL divergence cost 26.76\n",
      "Saving a sample...\n",
      "At iter: 3800  -  dt: 0:00:00.369248 - d_acc: 0.38, - d_acc_enc: 0.44\n",
      "Discriminator cost 17.12, Generator cost 63.71, VAE Cost 223.1, KL divergence cost 25.36\n",
      "Saving a sample...\n",
      "At iter: 4000  -  dt: 0:00:00.367608 - d_acc: 0.50, - d_acc_enc: 0.44\n",
      "Discriminator cost 21.59, Generator cost 84.74, VAE Cost 213.5, KL divergence cost 26.49\n",
      "Saving a sample...\n",
      "Printing validation set histograms at epoch 3\n",
      "ET Distribution plots are being printed...\n",
      "Done\n",
      "Resolution plots are being printed...\n",
      "Done\n",
      "Model saved in path: HCAL_bycicleGAN_test35/HCAL_bycicleGAN_test35bicycle.ckpt\n",
      "Per event simulation time 0:00:00.001866\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x576 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if __name__=='__main__':\n",
    "\n",
    "    if task == 'TRAIN':\n",
    "        if os.path.exists(PATH):\n",
    "            # The run directory already exists: if it holds a checkpoint, ask how to proceed.\n",
    "            if os.path.exists(PATH+'/checkpoint'):\n",
    "                ans = input('A previous checkpoint already exists, choose the action to perform \\n \\n 1) Overwrite the current model saved at '+PATH+'/checkpoint \\n 2) Start training a new model \\n 3) Restore and continue training the previous model \\n ')\n",
    "\n",
    "                if ans == '1':\n",
    "                    # Best-effort wipe of the run directory: only regular files are removed,\n",
    "                    # sub-directories are left alone, and failures are printed rather than raised.\n",
    "                    print('Overwriting existing model in '+PATH)\n",
    "                    for file in os.listdir(PATH):\n",
    "                        file_path = os.path.join(PATH, file)\n",
    "                        try:\n",
    "                            if os.path.isfile(file_path):\n",
    "                                os.unlink(file_path)\n",
    "                        except Exception as e:\n",
    "                            print(e)\n",
    "                elif ans == '2':\n",
    "                    # Switch the whole run to a brand-new directory chosen by the user.\n",
    "                    PATH = input('Specify the name of the model, a new directory will be created.\\n')\n",
    "                    os.mkdir(PATH)\n",
    "                # Any other answer (including '3') leaves PATH and its contents untouched.\n",
    "        else:\n",
    "            os.mkdir(PATH)\n",
    "\n",
    "        test_reco_NN = HCAL()\n",
    "\n",
    "    elif task == 'TEST':\n",
    "        # Testing requires a checkpoint previously saved under PATH.\n",
    "        if os.path.exists(PATH+'/checkpoint'):\n",
    "            test_reco_NN = HCAL()\n",
    "        else:\n",
    "            print('No checkpoint to test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest element of the MC reco sample (train_reco) and of the NN output (test_reco_NN).\n",
    "# NOTE(review): the NN *test* output is compared against the *training* MC maximum - confirm intended.\n",
    "max_MC, max_NN = train_reco.max(), test_reco_NN.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the two reference maxima side by side.\n",
    "print('Max NN {0}, Max_MC {1}'.format(max_NN, max_MC))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linearly rescale the NN output so that its maximum element maps onto max_MC.\n",
    "test_reco_NN_rescaled = max_MC * (test_reco_NN / test_reco_NN.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: after rescaling this should reproduce max_MC (up to floating-point rounding).\n",
    "print(np.max(test_reco_NN_rescaled))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sum_cells_per_event(images):\n",
    "    \"\"\"Flatten each event image (axis 0 indexes events) and sum all of its cells, one total per event.\"\"\"\n",
    "    n_cells = int(np.prod(images.shape[1:]))\n",
    "    return np.sum(images.reshape(images.shape[0], n_cells), axis=1)\n",
    "\n",
    "if preprocess:\n",
    "    # Undo the preprocessing before summing. Reco images (MC and NN) use the reco range;\n",
    "    # true images must use the true range (min_true, max_true), not min_reco.\n",
    "    reco_MC_hist = sum_cells_per_event(denormalise(test_reco, min_reco, max_reco))\n",
    "    reco_NN_hist = sum_cells_per_event(denormalise(test_reco_NN, min_reco, max_reco))\n",
    "    true_hist = sum_cells_per_event(denormalise(test_true, min_true, max_true))\n",
    "else:\n",
    "    # NOTE(review): only this branch uses the max-rescaled NN output (test_reco_NN_rescaled).\n",
    "    reco_MC_hist = sum_cells_per_event(test_reco)\n",
    "    reco_NN_hist = sum_cells_per_event(test_reco_NN_rescaled)\n",
    "    true_hist = sum_cells_per_event(test_true)\n",
    "\n",
    "max_E = np.max(reco_MC_hist)\n",
    "max_NN = np.max(reco_NN_hist)\n",
    "max_true_E = np.max(true_hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the NN per-event sums so that their maximum equals the MC maximum.\n",
    "# NOTE(review): forcing the maxima to agree affects every NN-vs-MC comparison that uses reco_NN_hist.\n",
    "reco_NN_hist = reco_MC_hist.max() * (reco_NN_hist / reco_NN_hist.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled alternative: per-event sums computed from the training arrays (train_reco / train_true)\n",
    "# instead of the test arrays. Kept for reference only; uncomment to use.\n",
    "#if preprocess:\n",
    "#    reco_MC_hist = denormalise(train_reco, min_reco, max_reco).reshape(train_reco.shape[0], train_reco.shape[1]*train_reco.shape[2])\n",
    "#    reco_MC_hist = np.sum(reco_MC_hist,axis=1)\n",
    "#    max_E=np.max(reco_MC_hist)\n",
    "#    \n",
    "#    reco_NN_hist = denormalise(test_reco_NN, min_reco, max_reco).reshape(test_reco_NN.shape[0], test_reco_NN.shape[1]*test_reco_NN.shape[2])\n",
    "#    reco_NN_hist = np.sum(reco_NN_hist,axis=1)\n",
    "#    max_NN = np.max(reco_NN_hist)\n",
    "#    \n",
    "#    true_hist = denormalise(train_true, min_true, max_true).reshape(train_true.shape[0], train_true.shape[1]*train_true.shape[2])\n",
    "#    true_hist = np.sum(true_hist,axis=1)\n",
    "#    max_true_E=np.max(true_hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolution: per-event difference between reconstructed and true summed E_T,\n",
    "# for the MC reconstruction (left panel) and the network output (right panel).\n",
    "diffNN = reco_NN_hist-true_hist\n",
    "diffMC = reco_MC_hist-true_hist\n",
    "\n",
    "fig, (ax_mc, ax_nn) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "ax_mc.tick_params(labelsize=15)\n",
    "h_reco = ax_mc.hist(diffMC, bins=30, edgecolor='black')\n",
    "ax_mc.set_xlabel('ET recoMC - ET true', fontsize=15)\n",
    "ax_mc.set_ylabel('dN/dETdiff', fontsize=15)\n",
    "ax_mc.set_title('Resolution as simulated by MC', fontsize=15)\n",
    "\n",
    "ax_nn.tick_params(labelsize=15)\n",
    "h_nn = ax_nn.hist(diffNN, bins=30, edgecolor='black')\n",
    "ax_nn.set_xlabel('ET recoNN - ET true', fontsize=15)\n",
    "ax_nn.set_ylabel('dN/dETdiff', fontsize=15)\n",
    "ax_nn.set_title('Resolution as simulated by NN', fontsize=15)\n",
    "\n",
    "# The displayed figure and the saved EPS are the same two panels.\n",
    "fig.savefig(PATH+'/resolution.eps', format='eps', dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bias (mean) and spread (standard deviation) of the NN resolution.\n",
    "(diffNN.mean(), diffNN.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bias (mean) and spread (standard deviation) of the MC resolution.\n",
    "(diffMC.mean(), diffMC.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of events whose summed true E_T is exactly zero.\n",
    "np.nonzero(true_hist == 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summed E_T per event: one panel each for the true pion, the MC reconstruction and the BicycleGAN output.\n",
    "# NOTE(review): dividing by 1000 for a GeV axis presumes the sums are in MeV - confirm.\n",
    "panels = [(true_hist, 'Pion True E_T'),\n",
    "          (reco_MC_hist, 'Pion Reco E_T from MC'),\n",
    "          (reco_NN_hist, 'Pion Reco E_T from BicycleGAN')]\n",
    "panel_hists = []\n",
    "for position, (values, panel_title) in enumerate(panels, start=1):\n",
    "    plt.subplot(1, 3, position)\n",
    "    plt.tick_params(labelsize=12)\n",
    "    panel_hists.append(plt.hist(values/1000, bins=30, edgecolor='black'))\n",
    "    plt.xlabel('E_T (GeV)', fontsize=15)\n",
    "    if position == 1:\n",
    "        # Only the leftmost panel carries the y-axis label.\n",
    "        plt.ylabel('dN/dE_T', fontsize=15)\n",
    "    plt.title(panel_title, fontsize=15)\n",
    "# Keep the MC and NN (counts, edges, patches) tuples under their usual names.\n",
    "h_reco, h_nn = panel_hists[1], panel_hists[2]\n",
    "\n",
    "fig = plt.gcf()\n",
    "fig.set_size_inches(16,4)\n",
    "plt.savefig(PATH+'/distribution.eps', format='eps', dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bin-by-bin difference between the NN and MC summed-E_T distributions.\n",
    "# Both samples are histogrammed on the SAME 30 bins (spanning both), so the counts are\n",
    "# directly comparable, and the bars are drawn at the real bin positions.\n",
    "# NOTE(review): /1000 matches the GeV axis of the distribution plot; presumes MeV sums - confirm.\n",
    "nn_et_gev = reco_NN_hist/1000\n",
    "mc_et_gev = reco_MC_hist/1000\n",
    "et_range = (min(nn_et_gev.min(), mc_et_gev.min()), max(nn_et_gev.max(), mc_et_gev.max()))\n",
    "counts_MC, bin_edges = np.histogram(mc_et_gev, bins=30, range=et_range)\n",
    "counts_NN, _ = np.histogram(nn_et_gev, bins=bin_edges)\n",
    "\n",
    "diff=plt.bar(bin_edges[:-1], height=(counts_NN-counts_MC), edgecolor='black',\n",
    "             linewidth=1, color='lightblue', width=np.diff(bin_edges), align='edge')\n",
    "plt.xlabel('E_T (GeV)')\n",
    "plt.ylabel('dN/dE_T')\n",
    "plt.title(\"NN output - MC output\")\n",
    "fig = plt.gcf()\n",
    "fig.set_size_inches(12,4)\n",
    "plt.savefig(PATH+'/difference.eps', format='eps',dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per-event residual: (sum of the cells returned by four_cells, the '4 max cells' per the axis label)\n",
    "# minus the event's total true E_T, for the rescaled NN output and for the MC reconstruction.\n",
    "n_events = len(test_reco)\n",
    "four_cells_diff_NN = np.array([four_cells(test_reco_NN_rescaled[i]).sum() - test_true[i].sum()\n",
    "                               for i in range(n_events)])\n",
    "four_cells_diff_MC = np.array([four_cells(test_reco[i]).sum() - test_true[i].sum()\n",
    "                               for i in range(n_events)])\n",
    "\n",
    "# Overlay both residual distributions on one shared binning so the bin heights are comparable.\n",
    "shared_bins = np.histogram(np.concatenate([four_cells_diff_NN, four_cells_diff_MC]), bins=30)[1]\n",
    "plt.hist(four_cells_diff_NN, bins=shared_bins, label = 'NN-ET_true')\n",
    "plt.hist(four_cells_diff_MC, bins=shared_bins, label = 'MC-ET_true', histtype='step')\n",
    "plt.legend(loc=2);\n",
    "# NOTE(review): the distribution plots divide by 1000 for GeV but these sums are not rescaled - confirm units.\n",
    "plt.xlabel('Sum of 4 max cells ET  - ET true (GeV)')\n",
    "plt.ylabel('dN/dET')\n",
    "plt.savefig(PATH+'/four_cells_diff_combined.eps', format='eps', dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Per-event difference between the NN and MC four-cell residuals, computed as NN minus MC.\n",
    "four_cells_diff = four_cells_diff_NN - four_cells_diff_MC\n",
    "plt.hist(four_cells_diff, bins=20, label='NN-MC');\n",
    "plt.legend();\n",
    "plt.xlabel('Four-cell residual difference, NN - MC')\n",
    "plt.ylabel('Number of events')\n",
    "print('four cells diff mean {0}, std {1}'.format(four_cells_diff.mean(), four_cells_diff.std()))\n",
    "plt.savefig(PATH+'/four_cells_diff.eps', format='eps', dpi=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot of the run's settings. The notebook's first cell reloads hyper-parameters\n",
    "# from the pickle of this dictionary when task == 'TRAIN'.\n",
    "hyper_dict = dict(\n",
    "    LEARNING_RATE=LEARNING_RATE,\n",
    "    BETA1=BETA1,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPOCHS=EPOCHS,\n",
    "    SAVE_SAMPLE_PERIOD=SAVE_SAMPLE_PERIOD,\n",
    "    SEED=SEED,\n",
    "    d_sizes=d_sizes,\n",
    "    g_sizes_dec=g_sizes_dec,\n",
    "    g_sizes_enc=g_sizes_enc,\n",
    "    e_sizes=e_sizes,\n",
    "    preprocess=preprocess,\n",
    "    cost_type=cost_type,\n",
    "    validating_size=validating_size,\n",
    "    test_size=test_size,\n",
    "    n_batches=n_batches,\n",
    "    reco_path=reco_path,\n",
    "    true_path=true_path,\n",
    "    discr_steps=discr_steps,\n",
    "    gen_steps=gen_steps,\n",
    "    vae_steps=vae_steps,\n",
    "    latent_weight=latent_weight,\n",
    "    cycl_weight=cycl_weight,\n",
    "    kl_weight=kl_weight,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the hyper-parameters inside the run directory; the first cell reloads this file\n",
    "# when task == 'TRAIN' and it exists.\n",
    "hyper_file_path = PATH+'/hyper_parameters.pkl'\n",
    "with open(hyper_file_path, 'wb') as pkl_file:\n",
    "    pickle.dump(hyper_dict, pkl_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}