{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import all packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "from sklearn import preprocessing\n",
    "import pickle as pkl\n",
    "from pathlib import Path\n",
    "from keras.datasets import imdb\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import LSTM\n",
    "from keras.layers import GRU\n",
    "from keras.layers import Dropout, BatchNormalization\n",
    "from keras.layers import ConvLSTM2D\n",
    "from keras.layers import Conv1D\n",
    "#from keras.layers.convolutional import Conv1D\n",
    "#from keras.layers.convolutional import MaxPooling1D\n",
    "from keras.layers.embeddings import Embedding\n",
    "from keras.preprocessing import sequence\n",
    "from keras.callbacks import History\n",
    "from keras.callbacks import EarlyStopping\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from keras.models import load_model\n",
    "\n",
    "from xgboost import XGBClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "#import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import the dataset of the matched 8-hit tracks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import data as array\n",
    "# 8 hits with x,y,z\n",
    "\n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')\n",
    "#print(testset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert the data to an array (float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Convert the data\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "\n",
    "#Check testset with arbitrary particle\n",
    "\n",
    "#print(tset.shape)\n",
    "#for i in range(8):\n",
    "    #print(tset[1,3*i:(3*i+3)])\n",
    "#print(tset[0,:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transformation between original 2D-array into 3D-array\n",
    "\n",
    "#### reshapor()\n",
    "\n",
    "Description:\n",
    "\n",
    "Transforms 2D-array into 3D array\n",
    "\n",
    "Arguments:\n",
    "\n",
    "- arr_orig: Original 2D array\n",
    "- num_inputs: Number of inputs per timestep (default value = 3 for X,Y,Z coordinates)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: 3D-array of shape(particlenumber, timesteps, input = coordinates)\n",
    "\n",
    "#### reshapor_inv()\n",
    "\n",
    "Description:\n",
    "\n",
    "Inverse transformation from 3D-array into 2D-array\n",
    "\n",
    "Arguments:\n",
    "\n",
    "- array_shaped: 3D-array of shape(particlenumber, timesteps, input = coordinates)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: 2D-array of shape(particlenumber, inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Reshapes the 2D-array to a 3D-array\n",
    "\n",
    "def reshapor(arr_orig, num_inputs=3):\n",
    "    timesteps = int(arr_orig.shape[1]/num_inputs)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    arr = np.zeros((number_examples, timesteps, num_inputs))\n",
    "    \n",
    "    for i in range(number_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,t,:] = arr_orig[i,num_inputs*t:num_inputs*t+num_inputs]\n",
    "        \n",
    "    return arr\n",
    "\n",
    "#The inverse transformation of the reshapor function (3D to 2D)\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    num_inputs = array_shaped.shape[2]\n",
    "    timesteps = int(array_shaped.shape[1])\n",
    "    num_examples = int(array_shaped.shape[0])\n",
    "    arr = np.zeros((num_examples, timesteps*num_inputs))\n",
    "    \n",
    "    for i in range(num_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,num_inputs*t:num_inputs*t+num_inputs] = array_shaped[i,t,:]\n",
    "        \n",
    "    return arr"
   ]
  },
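  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick round-trip check of reshapor() and reshapor_inv() (an illustrative sketch on dummy data, not part of the original analysis): reshaping a 2D-array to 3D and back should reproduce the input exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative round-trip check for reshapor()/reshapor_inv()\n",
    "\n",
    "_demo = np.arange(2*24, dtype='float32').reshape(2, 24) #2 tracks, 8 hits with x,y,z\n",
    "_demo3d = reshapor(_demo)\n",
    "print(_demo3d.shape)                               #expected: (2, 8, 3)\n",
    "print(np.allclose(reshapor_inv(_demo3d), _demo))   #expected: True"
   ]
  },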
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create random training and test sets from the data\n",
    "\n",
    "#### create_random_sets()\n",
    "\n",
    "Description:\n",
    "\n",
    "Splits an dataset into a train and a test set\n",
    "\n",
    "\n",
    "Input:\n",
    "\n",
    "- dataset: The actual dataset with shape (particles, other dimensions)\n",
    "- train_to_total_ratio: The ratio that the training-set should be out of the original set.\n",
    "    The remaining part will become the test-set\n",
    "    \n",
    "\n",
    "Returns:\n",
    "\n",
    "- train_set: The newly created training set (particles, other dimensions)\n",
    "- test_set: The newly created test set (particles, other dimensions)\n",
    " \n",
    " \n",
    "Additional comments:\n",
    "\n",
    "The data will be randomly shuffled before it gets split up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "    \n",
    "    #evaluate size of training and test set and initialize them\n",
    "    train_set_size = np.int(num_examples*train_to_total_ratio)\n",
    "    test_set_size = num_examples - train_set_size\n",
    "    \n",
    "    train_set = np.zeros((train_set_size, dataset.shape[1]))\n",
    "    test_set = np.zeros((test_set_size, dataset.shape[1]))\n",
    "   \n",
    "\n",
    "    #fill train and test sets\n",
    "    for i in range(num_examples):\n",
    "        if train_set_size > i:\n",
    "            train_set[i,:] += dataset[i,:]\n",
    "        else:\n",
    "            test_set[i - train_set_size,:]  += dataset[i,:]\n",
    "                \n",
    "    return train_set, test_set\n",
    "        "
   ]
  },
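  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny illustration of the split (sketch only, on dummy data): with a ratio of 0.9, 10 examples should be shuffled and then split 9/1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative check of create_random_sets() on a dummy array\n",
    "\n",
    "_dummy = np.arange(10*24, dtype='float32').reshape(10, 24)\n",
    "_tr, _te = create_random_sets(_dummy, 0.9)\n",
    "print(_tr.shape, _te.shape)   #expected: (9, 24) (1, 24)"
   ]
  },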
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create the training and test-sets\n",
    "\n",
    "train_set, test_set = create_random_sets(tset, 0.9)\n",
    "\n",
    "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "#print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalization of the data\n",
    "\n",
    "### Normalization on a min_max_scaler from sklearn\n",
    "\n",
    "#### correct_array_steps()\n",
    "\n",
    "Description: As the scaler will be fixed on arrays of specific length this function returns an array padded with zeros with the correct shape\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- steps: Required number of timesteps for the scaler (default value = 8)\n",
    "- num_inputs: Number of inputs per timestep (default value = 3 for X,Y,Z coordinates)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: 3D array of shape(particle_number, steps, num_inputs)\n",
    "\n",
    "#### set_min_max_scaler()\n",
    "\n",
    "Description: Sets the min_max_scaler based on the dataset given (sklearn based)\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 2D of shape(particle_number, inputs) or 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- feature_range: Tuple which defines the area to which the data should be scaled (default value = (-1,1))\n",
    "\n",
    "Returns:\n",
    "\n",
    "- min_max_scalor: min_max_scaler based of the data given\n",
    "\n",
    "#### min_max_scaler()\n",
    "\n",
    "Description: Transforms a 3D-array with a given min_max_scaler (sklearn based)\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- min_max_scalor: The min_max_scaler used for the transformation (default value: min_max_scalor)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: Transformed 3D-array\n",
    "\n",
    "#### min_max_scaler_inv()\n",
    "\n",
    "Description: Transforms a 3D-array with a given min_max_scaler back to original form (sklearn based)\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- min_max_scalor: The min_max_scaler used for the transformation (default value: min_max_scalor)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: Transformed 3D-array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn\n",
    "def correct_array_steps(arr, steps= 8, num_inputs= 3): #steps > array_steps\n",
    "        if arr.shape[1] != steps:\n",
    "            _ = np.zeros((arr.shape[0], steps, num_inputs))\n",
    "            _[:,:arr.shape[1],:] += arr\n",
    "            arr = _\n",
    "        return arr\n",
    "\n",
    "\n",
    "#set the transormation based on training set\n",
    "def set_min_max_scaler(arr, feature_range= (-1,1)):\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    if len(arr.shape) == 3:\n",
    "        arr = reshapor(min_max_scalor.fit_transform(reshapor_inv(arr)))        \n",
    "    else:\n",
    "        arr = min_max_scalor.fit_transform(arr)\n",
    "    return min_max_scalor\n",
    "\n",
    "min_max_scalor = set_min_max_scaler(train_set)\n",
    "\n",
    "\n",
    "#transform data\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "    num_inputs = arr.shape[2]\n",
    "    arr = correct_array_steps(arr)\n",
    "    arr = reshapor(min_max_scalor.transform(reshapor_inv(arr)), num_inputs=num_inputs)\n",
    "    return arr\n",
    "        \n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "    num_inputs = arr.shape[2]\n",
    "    arr = correct_array_steps(arr)\n",
    "    arr = reshapor(min_max_scalor.inverse_transform(reshapor_inv(arr)), num_inputs=num_inputs)\n",
    "    return arr"
   ]
  },
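  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sanity check for the min_max scaler (illustrative sketch): scaled training data should lie within the feature_range (-1,1), and the inverse transformation should recover the original coordinates. Note that train_set is still a 2D-array at this point, so it is reshaped first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative round-trip check for the min_max scaler\n",
    "\n",
    "_sample = reshapor(train_set[:5])     #train_set is still 2D here\n",
    "_scaled = min_max_scaler(_sample)\n",
    "print(_scaled.min(), _scaled.max())   #should lie within (-1, 1)\n",
    "print(np.allclose(min_max_scaler_inv(_scaled), _sample))   #expected: True"
   ]
  },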
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Normalization based on a standard_scaler from sklearn\n",
    "\n",
    "\n",
    "#### set_std_scaler()\n",
    "\n",
    "Description: Sets the std_scaler based on the dataset given (sklearn based)\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 2D of shape(particle_number, inputs) or 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- feature_range: Tuple which defines the area to which the data should be scaled (default value = (-1,1))\n",
    "\n",
    "Returns:\n",
    "\n",
    "- std_scaler: std_scaler based of the data given\n",
    "\n",
    "#### std_scaler()\n",
    "\n",
    "Description: Transforms a 3D-array with a given std_scaler (sklearn based)\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- std_scaler: The std_scaler used for the transformation (default value: std_scaler)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: Transformed 3D-array\n",
    "\n",
    "#### std_scaler_inv()\n",
    "\n",
    "Description: Transforms a 3D-array with a given std_scaler back to original form (sklearn based)\n",
    "\n",
    "Input:\n",
    "\n",
    "- arr: 3D-array of shape(particle_number, timesteps, num_inputs)\n",
    "- min_max_scalor: The std_scaler used for the transformation (default value: std_scaler)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- arr: Transformed 3D-array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn - Standard scaler\n",
    "\n",
    "#set the transormation based on training set\n",
    "def set_std_scaler(arr):\n",
    "    std_scalor = preprocessing.StandardScaler()\n",
    "    if len(arr.shape) == 3:\n",
    "        arr = reshapor(std_scalor.fit_transform(reshapor_inv(arr)))        \n",
    "    else:\n",
    "        arr = std_scalor.fit_transform(arr)\n",
    "    return std_scalor\n",
    "\n",
    "std_scalor = set_std_scaler(train_set)\n",
    "\n",
    "#transform data\n",
    "def std_scaler(arr, std_scalor= std_scalor, num_inputs=3):\n",
    "    arr = correct_array_steps(arr)\n",
    "    arr = reshapor(std_scalor.transform(reshapor_inv(arr)))\n",
    "    return arr\n",
    "        \n",
    "#inverse transformation\n",
    "def std_scaler_inv(arr, std_scalor= std_scalor, num_inputs=3):\n",
    "    arr = correct_array_steps(arr)\n",
    "    arr = reshapor(std_scalor.inverse_transform(reshapor_inv(arr)))\n",
    "    return arr\n",
    "\n"
   ]
  },
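  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An analogous check for the standard scaler (illustrative sketch): on the set it was fitted on, every flattened feature should come out with mean close to 0 and standard deviation close to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative check for the standard scaler\n",
    "\n",
    "_flat = reshapor_inv(std_scaler(reshapor(train_set)))   #train_set is still 2D here\n",
    "print(np.round(_flat.mean(axis=0), 3))   #expected: ~0 for every feature\n",
    "print(np.round(_flat.std(axis=0), 3))    #expected: ~1 for every feature"
   ]
  },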
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#reshape the data\n",
    "\n",
    "train_set = reshapor(train_set)\n",
    "test_set = reshapor(test_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Scale data either with MinMax scaler or with Standard scaler\n",
    "#Return scalor if fit = True and and scaled array otherwise\n",
    "\n",
    "def scaler(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"minmax\", scalor = False):\n",
    "    \n",
    "    if scalor != False:\n",
    "        arr = correct_array_steps(arr)\n",
    "        arr = reshapor(scalor.transform(reshapor_inv(arr)))\n",
    "        return arr\n",
    "    \n",
    "    elif scalerfunc == \"std\":\n",
    "        arr = std_scaler(arr, std_scalor= std_scalor)\n",
    "        return arr\n",
    "    \n",
    "    elif scalerfunc == \"minmax\":\n",
    "        arr = min_max_scaler(arr, min_max_scalor= min_max_scalor)\n",
    "        return arr\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n",
    "\n",
    "def scaler_inv(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\", scalor = False, num_inputs= 3):\n",
    "\n",
    "    if scalor != False:\n",
    "        arr = correct_array_steps(arr)\n",
    "        arr = reshapor(scalor.inverse_transform(reshapor_inv(arr)))\n",
    "        return arr\n",
    "    \n",
    "    elif scalerfunc == \"std\":\n",
    "        arr = std_scaler_inv(arr, std_scalor= std_scalor)\n",
    "        return arr\n",
    "    \n",
    "    elif scalerfunc == \"minmax\":\n",
    "        arr = min_max_scaler_inv(arr, min_max_scalor= min_max_scalor)\n",
    "        return arr\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n"
   ]
  },
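  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An illustration of the two ways to call the wrappers (sketch only): choosing a scaler by name via scalerfunc, or passing an already fitted scalor object. Both paths should agree, and scaler_inv() should undo scaler()."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative use of the scaler()/scaler_inv() wrappers\n",
    "\n",
    "_sample = train_set[:3]   #train_set is already reshaped to 3D here\n",
    "_a = scaler(_sample, scalerfunc= \"minmax\")\n",
    "_b = scaler(_sample, scalor= min_max_scalor)\n",
    "print(np.allclose(_a, _b))   #expected: True\n",
    "print(np.allclose(scaler_inv(_a, scalerfunc= \"minmax\"), _sample))   #expected: True"
   ]
  },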
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scale the data\n",
    "\n",
    "func = \"minmax\"\n",
    "\n",
    "train_set = scaler(train_set, scalerfunc = func)\n",
    "test_set = scaler(test_set, scalerfunc = func)\n",
    "\n",
    "if func == \"minmax\":\n",
    "    scalor = min_max_scalor\n",
    "elif func == \"std\":\n",
    "    scalor = std_scalor\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 100):\n",
    "    \n",
    "    num_examples = inputt.shape[0]\n",
    "    \n",
    "    \n",
    "    #Number of complete batches\n",
    "    \n",
    "    number_of_batches = int(num_examples/minibatch_size)\n",
    "    minibatches = []\n",
    "   \n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    #print(_t.shape)\n",
    "        \n",
    "    \n",
    "    for i in range(number_of_batches):\n",
    "        \n",
    "        minibatch_train = _i[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatch_true = _t[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatches.append((minibatch_train, minibatch_true))\n",
    "        \n",
    "        \n",
    "    minibatches.append((_i[number_of_batches*minibatch_size:, :, :], _t[number_of_batches*minibatch_size:, :, :]))\n",
    "    \n",
    "    \n",
    "    return minibatches\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create random minibatches of train and test set with input and target array\n",
    "\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
    "#_train, _target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "#print(train[0,:,:], target[0,:,:])"
   ]
  },
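  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check on the created minibatches (illustrative sketch): every full batch holds 1000 examples with 7 input timesteps and 3 coordinates, and the targets are the inputs shifted by one timestep."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustrative shape check of the minibatches\n",
    "\n",
    "print(len(minibatches))                      #number of full batches (+ remainder batch, if any)\n",
    "_mb_in, _mb_tar = minibatches[0]\n",
    "print(_mb_in.shape, _mb_tar.shape)           #expected: (1000, 7, 3) (1000, 7, 3)\n",
    "print(test_input.shape, test_target.shape)   #held-out pair used for validation"
   ]
  },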
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
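  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The RNN model class\n",
    "\n",
    "#### RNNPlacePrediction()\n",
    "\n",
    "Description:\n",
    "\n",
    "Builds a multi-layer RNN that predicts the coordinates of the next hit from the previous hits\n",
    "\n",
    "Arguments:\n",
    "\n",
    "- time_steps: Number of timesteps of the input\n",
    "- future_steps: Number of steps to predict into the future\n",
    "- ninputs: Number of inputs per timestep (3 for X,Y,Z coordinates)\n",
    "- ncells: Number of cells per layer, given as int or list of int\n",
    "- num_output: Number of outputs per timestep\n",
    "- cell_type: RNN cell type (\"basic_rnn\", \"lstm\" or \"GRU\")\n",
    "- activation: Activation function (\"relu\", \"tanh\", \"leaky_relu\" or \"elu\"); the last layer always uses tanh\n",
    "- scalor: The scaler used on the data (stored so it can be saved together with the model)"
   ]
  },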
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \n",
    "    \n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\", scalor= scalor):\n",
    "        \n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "        self.loss_list = []\n",
    "        self.loss_validation = []\n",
    "        self.scalor = scalor\n",
    "        \n",
    "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
    "        #### ninputs is the dimentionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "\n",
    "        \n",
    "        #Check if activation function valid and set activation\n",
    "        if self.__==\"relu\":\n",
    "            self.activation = tf.nn.relu\n",
    "            \n",
    "        elif self.__==\"tanh\":\n",
    "            self.activation = tf.nn.tanh\n",
    "                    \n",
    "        elif self.__==\"leaky_relu\":\n",
    "            self.activation = tf.nn.leaky_relu\n",
    "            \n",
    "        elif self.__==\"elu\":\n",
    "            self.activation = tf.nn.elu\n",
    "            \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(self.__))\n",
    "        \n",
    "        \n",
    "        \n",
    "        #Check if cell type valid and set cell_type\n",
    "        if self._==\"basic_rnn\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicRNNCell\n",
    "            \n",
    "        elif self._==\"lstm\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicLSTMCell\n",
    "                    \n",
    "        elif self._==\"GRU\":\n",
    "            self.cell_type = tf.contrib.rnn.GRUCell\n",
    "            \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(self._))\n",
    "            \n",
    "        \n",
    "        #Check Input of ncells        \n",
    "        if (type(self.ncells) == int):\n",
    "            self.ncells = [self.ncells]\n",
    "        \n",
    "        if (type(self.ncells) != list):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        \n",
    "        for _ in range(len(self.ncells)):\n",
    "            if type(self.ncells[_]) != int:\n",
    "                raise ValueError(\"Wrong type of Input for ncells\")\n",
    "                \n",
    "        self.activationlist = []\n",
    "        for _ in range(len(self.ncells)-1):\n",
    "            self.activationlist.append(self.activation)\n",
    "        self.activationlist.append(tf.nn.tanh)\n",
    "        \n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n",
    "                                                 for layer in range(len(self.ncells))])\n",
    "            \n",
    "        \n",
    "        #### I now define the output\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        self.sess = tf.Session()\n",
    "        \n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the mean_squared_error (distance of predicted point to target)\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))   \n",
    "        \n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "  \n",
    "    \n",
    "    def save(self, rnn_folder=\"./rnn_model/rnn_basic\"):\n",
    "        self.saver.save(self.sess, rnn_folder)       \n",
    "            \n",
    "            \n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        self.saver.restore(self.sess, filename)\n",
    "\n",
    "               \n",
    "        \n",
    "    def fit(self, minibatches, epochs, print_step, validation_input, validation_output, checkpoint = 5, patience = 20, patience_trigger= 1.5/10**6):\n",
    "        patience_cnt = 0\n",
    "        start = len(self.loss_list)\n",
    "        epoche_save = start\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "        \n",
    "        for iep in range(start, start + epochs):\n",
    "            loss = 0\n",
    "            loss_val = 0\n",
    "            \n",
    "            batches = len(minibatches)\n",
    "            #Here I iterate over the batches\n",
    "            for batch in range(batches):\n",
    "            #### Here I train the RNNcell\n",
    "            #### The X is the time series, the Y is shifted by 1 time step\n",
    "                train, target = minibatches[batch]\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "                \n",
    "            \n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "                loss_val += rnn.sess.run(self.cost, feed_dict={self.X:validation_input, rnn.Y:validation_output})\n",
    "            \n",
    "            #Normalize loss over number of batches and scale it back before normaliziation\n",
    "            loss /= batches\n",
    "            loss_val /= batches\n",
    "            self.loss_list.append(loss)\n",
    "            self.loss_validation.append(loss_val)\n",
    "            \n",
    "            #print(loss)\n",
    "            \n",
    "            #Here I create the checkpoint if the perfomance is better\n",
    "            if iep > 1 and iep%checkpoint == 0 and self.loss_validation[iep] < self.loss_validation[epoche_save]:\n",
    "                #print(\"Checkpoint created at epoch: \", iep)\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "            \n",
    "            #early stopping with patience\n",
    "            if iep > 1 and abs(self.loss_validation[iep]-self.loss_validation[iep-1]) < patience_trigger:\n",
    "                patience_cnt += 1\n",
    "                #print(\"Patience now at: \", patience_cnt, \" of \", patience)\n",
    "                \n",
    "                if patience_cnt + 1 > patience:\n",
    "                    print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_validation[iep]-self.loss_validation[iep-1]))\n",
    "                    print(\"Cost: \", loss*10**6, \"e-6\")\n",
    "                    print(\"Cost on valdiation_set: \",loss_val*10**6, \"e-6\")\n",
    "                    break\n",
    "            \n",
    "            #Note that the loss here is multiplied with 1000 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*10**6, \"e-6\")\n",
    "                print(\"Cost on validation_set: \",loss_val*10**6, \"e-6\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "        \n",
    "        #Set model back to the last checkpoint if performance was better\n",
    "        if self.loss_validation[epoche_save] < self.loss_validation[iep]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\")\n",
    "            print(\"State of last checkpoint checkpoint at epoch \", epoche_save, \" restored\")\n",
    "            print(\"Performance at last checkpoint is \" ,(self.loss_list[iep] - self.loss_list[epoche_save])/self.loss_list[iep]*100, \"% better\" )\n",
    "\n",
    "        \n",
    "        print(\"\\n\")\n",
    "        print(\"Model saved in at: \", folder)\n",
    "            \n",
    "            \n",
    "              \n",
    "        \n",
    "    def predict(self, x):\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#saves the rnn model and all its parameters including the scaler used\n",
    "#optional also saves the minibatches used to train and the test set\n",
    "\n",
    "def full_save(rnn, train= True, test= True):\n",
    "    folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    rnn.save(folder)\n",
    "    pkl_name = folder[2:-10] + \".pkl\"\n",
    "    \n",
    "    \n",
    "    pkl_dic = {\"ncells\": rnn.ncells,\n",
    "              \"ninputs\": rnn.ninputs,\n",
    "              \"future_steps\": rnn.future_steps,\n",
    "              \"nsteps\": rnn.nsteps,\n",
    "              \"num_output\": rnn.num_output,\n",
    "              \"cell_type\": rnn._, #cell_type\n",
    "              \"activation\": rnn.__, #Activation\n",
    "              \"loss_list\": rnn.loss_list,\n",
    "              \"scalor\": rnn.scalor,\n",
    "              \"loss_validation\": rnn.loss_validation}\n",
    "    \n",
    "    if train == True:\n",
    "        pkl_dic[\"minibatches\"] = minibatches\n",
    "    \n",
    "    if test == True:\n",
    "        pkl_dic[\"test_input\"] = test_input\n",
    "        pkl_dic[\"test_target\"] = test_target\n",
    "        \n",
    "    pkl.dump( pkl_dic, open(pkl_name , \"wb\" ) )\n",
    "    \n",
    "    print(\"Model saved at: \", folder)\n",
    "    print(\"Remaining data saved as: {}\".format(pkl_name))\n",
    "\n",
    "\n",
    "\n",
    "#loads the rnn model with all its parameters including the scaler used\n",
    "#Checks if the pkl data also contains the training or test sets an return them accordingly\n",
    "def full_load(folder): \n",
    "    #returns state of rnn with all information and returns the train and test set used\n",
    "    \n",
    "    #Directory of pkl file\n",
    "    pkl_name = folder[2:-10] + \".pkl\"\n",
    "    \n",
    "    #Check if pkl file exists\n",
    "    my_file = Path(pkl_name)\n",
    "    if my_file.is_file() == False:\n",
    "        raise ValueError(\"There is no .pkl file with the name: {}\".format(pkl_name))\n",
    "        \n",
    "    pkl_dic = pkl.load( open(pkl_name , \"rb\" ) )\n",
    "    ncells = pkl_dic[\"ncells\"]\n",
    "    ninputs = pkl_dic[\"ninputs\"]\n",
    "    scalor = pkl_dic[\"scalor\"]\n",
    "    future_steps = pkl_dic[\"future_steps\"]\n",
    "    timesteps = pkl_dic[\"nsteps\"] \n",
    "    num_output = pkl_dic[\"num_output\"]\n",
    "    cell_type = pkl_dic[\"cell_type\"]\n",
    "    activation = pkl_dic[\"activation\"]\n",
    "    \n",
    "    #Check if test or trainng set in dictionary\n",
    "    batch = False\n",
    "    test = False\n",
    "    if \"minibatches\" in pkl_dic:\n",
    "        batch = True\n",
    "        minibatches = pkl_dic[\"minibatches\"]\n",
    "    if \"test_input\" in pkl_dic:\n",
    "        test = True\n",
    "        test_input = pkl_dic[\"test_input\"]\n",
    "        test_target = pkl_dic[\"test_target\"]\n",
    "    \n",
    "    #loads and initializes a new model with the exact same properties\n",
    "    \n",
    "    tf.reset_default_graph()\n",
    "    rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=cell_type, activation=activation, scalor=scalor)\n",
    "\n",
    "    rnn.set_cost_and_functions()\n",
    "    \n",
    "    rnn.load(folder)\n",
    "    \n",
    "    rnn.loss_list = pkl_dic[\"loss_list\"]\n",
    "    \n",
    "    rnn.loss_validation = pkl_dic[\"loss_validation\"]\n",
    "    \n",
    "    print(\"Model succesfully loaded\")\n",
    "    \n",
    "    if batch and test:\n",
    "        data = [minibatches, test_input, test_target]\n",
    "        print(\"Minibatches (=training data) and test_input and test_target in data loaded\")\n",
    "        return rnn, data\n",
    "        \n",
    "    elif batch:\n",
    "        data = [minibatches]\n",
    "        print(\"Minibatches (=training data) loaded in data\")\n",
    "        return rnn, data\n",
    "        \n",
    "    elif test:\n",
    "        data = [test_input, test_target]\n",
    "        print(\"test_input and test_target loaded in data\")\n",
    "        return rnn, data\n",
    "    \n",
    "    else:\n",
    "        data = []\n",
    "        print(\"Only Model restored, no trainig or test data found in {}\".format(pkl_name))\n",
    "        print(\"Returned data is empty!\")\n",
    "        return rnn, data\n",
    "\n",
    "#returns the folder name used by full_save and full_load for a given architecture\n",
    "def get_rnn_folder(ncells, cell_type, activation):\n",
    "    folder = \"./rnn_model_\" + cell_type + \"_\" + activation + \"_\" + str(ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    return folder"
   ]
  },
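  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small illustration of the folder naming scheme shared by full_save(), full_load() and get_rnn_folder(): cell type, activation and layer sizes are all encoded in the directory name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Illustration of the folder naming scheme\n",
    "\n",
    "print(get_rnn_folder(ncells= [150, 150, 150], cell_type= \"lstm\", activation= \"leaky_relu\"))\n",
    "#expected: ./rnn_model_lstm_leaky_relu_[150,150,150]c/rnn_basic"
   ]
  },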
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [150, 150, 150]\n",
    "activation = \"leaky_relu\"\n",
    "cell_type = \"lstm\"\n",
    "\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=cell_type, activation=activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "#rnn.set_cost_and_functions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#rnn.fit(minibatches, epochs = 5000, print_step=10, validation_input = test_input, validation_output= test_target)\n",
    "#full_save(rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#Plot the loss\n",
    "def plot_loss_list(loss_list = rnn.loss_list, loss_validation = rnn.loss_validation):\n",
    "    plt.plot(rnn.loss_list, label='Loss on training set')\n",
    "    plt.plot(rnn.loss_validation,  label='Loss on test set')\n",
    "    plt.legend()\n",
    "    plt.xlabel(\"Epoch\")\n",
    "    plt.ylabel(\"Cost\")\n",
    "    plt.show()\n",
    "\n",
    "#plot_loss_list(rnn.loss_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm_leaky_relu_[150,150,150]c/rnn_basic\n",
      "Model succesfully loaded\n",
      "Minibatches (=training data) and test_input and test_target in data loaded\n"
     ]
    }
   ],
   "source": [
    "folder = get_rnn_folder(ncells = ncells, cell_type = cell_type, activation = activation)\n",
    "rnn, data = full_load(folder)\n",
    "minibatches, test_input, test_target = data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4690, 7, 3)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_input.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rnn_test(rnn, test_input= test_input, test_target= test_target, scalor= rnn.scalor):\n",
    "    \n",
    "    #Here I predict based on my test set\n",
    "    test_pred = rnn.predict(test_input)\n",
    "    \n",
    "    #Here i subtract a prediction (random particle) from the target to get an idea of the predictions\n",
    "    #scaler_inv(test_input, scalerfunc = func)[0,:,:]\n",
    "    diff = scaler_inv(test_pred, scalerfunc = func, scalor= scalor)-scaler_inv(test_target, scalerfunc = func, scalor= scalor)\n",
    "    print(diff[random.randint(0,test_pred.shape[0]),:,:])\n",
    "    \n",
    "    #Here I evaluate my model on the test set based on mean_squared_error\n",
    "    loss = rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})\n",
    "    print(\"Loss on test set:\", loss)\n",
    "    \n",
    "    return test_pred, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-3.18470338e-01 -1.90207179e-01 -1.40898075e-03]\n",
      " [-1.59599343e+00  1.97988971e-01 -1.11545922e+00]\n",
      " [-7.32249507e-02  4.83462188e-01  6.15174259e-01]\n",
      " [ 3.13725194e+00 -6.94178227e-01 -3.67514878e+00]\n",
      " [ 5.84199409e-01  4.98911750e-01 -4.66140907e-02]\n",
      " [ 1.32929954e+00  5.03890275e-01 -7.67828468e-02]\n",
      " [ 4.41919712e-01 -4.32863707e-01  6.22717514e-02]\n",
      " [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "Loss on test set: 0.0017876212\n"
     ]
    }
   ],
   "source": [
    "test_pred, test_loss = rnn_test(rnn=rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "tset_matched = pd.read_pickle('matched_and_unmatched_8hittracks.pkl')\n",
    "#test = pd.read_pickle('matched_and_unmatched_8hittracks2.pkl')\n",
    "#tset_matched\n",
    "#test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "tset_matched = np.array(tset_matched)\n",
    "tset_matched = tset_matched.astype('float32')\n",
    "truth = tset_matched[:,-1]\n",
    "tset_matched = scaler(reshapor(tset_matched[:,:-1]), scalerfunc = func, scalor= scalor)\n",
    "#print(reshapor_inv(tset_matched).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "tset_matched = reshapor_inv(tset_matched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "#print(tset_matched.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tracks_to_particle(tset_matched, truth):\n",
    "    start = 0\n",
    "    start_points = [0]\n",
    "    converse = False\n",
    "    \n",
    "    if len(tset_matched.shape) == 3:\n",
    "        tset_matched = reshapor_inv(tset_matched)\n",
    "        converse = True\n",
    "    \n",
    "    for track in range(tset.shape[0]-1):\n",
    "    \n",
    "        for coord in range(12):\n",
    "        \n",
    "            if tset_matched[track, coord] != tset_matched[track+1, coord]:\n",
    "                start = track + 1\n",
    "    \n",
    "        if start != start_points[-1]:\n",
    "            start_points.append(start)\n",
    "\n",
    "    num_part = len(start_points)\n",
    "\n",
    "    particle_tracks = []\n",
    "    track_truth = []\n",
    "\n",
    "    if converse:\n",
    "        tset_matched = reshapor(tset_matched)\n",
    "    \n",
    "    for particle in range(num_part-1):\n",
    "        particle_tracks.append(reshapor(tset_matched[start_points[particle]:start_points[particle+1]]))\n",
    "        track_truth.append(truth[start_points[particle]:start_points[particle+1]])\n",
    "        \n",
    "    \n",
    "    return particle_tracks, track_truth"
   ]
  },
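  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A synthetic example for tracks_to_particle() (illustrative only, with made-up coordinates): two fake particles whose candidate tracks share their first four hits should be grouped into bundles of 2 and 3 candidate tracks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Synthetic check of tracks_to_particle()\n",
    "\n",
    "_fake = np.zeros((5, 24), dtype='float32')\n",
    "_fake[0:2, :12] = 1.0                     #particle A: common first 4 hits\n",
    "_fake[2:5, :12] = 2.0                     #particle B: common first 4 hits\n",
    "_fake[:, 12:] = np.random.rand(5, 12)     #the candidate continuations differ\n",
    "_truth = np.array([1, 0, 0, 1, 0])\n",
    "\n",
    "_tracks, _truths = tracks_to_particle(tset_matched= _fake, truth= _truth)\n",
    "print(len(_tracks))                        #expected: 2 particles\n",
    "print(_tracks[0].shape, _tracks[1].shape)  #expected: (2, 8, 3) (3, 8, 3)"
   ]
  },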
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "particle_tracks, track_truth = tracks_to_particle(tset_matched= tset_matched, truth= truth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "#print(particle_tracks[11])\n",
    "#print(track_truth[11])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "#particle_tracks[1][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "#num_particles = len(particle_tracks)\n",
    "#num_tracks = len(particle_tracks[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_best_tracks(particle_tracks):\n",
    "\n",
    "    generated_truth_list = []\n",
    "    loss_list = []\n",
    "    num_particles = len(particle_tracks)\n",
    "    \n",
    "    for particle in range(num_particles):\n",
    "    \n",
    "        num_tracks = len(particle_tracks[particle])\n",
    "        min_loss = 10\n",
    "        part_loss_list = np.zeros((num_tracks))\n",
    "        truth = np.zeros((num_tracks))\n",
    "        \n",
    "        for track in range(num_tracks):\n",
    "            inputt = np.zeros((1,7,3))\n",
    "            inputt[0,:,:] = particle_tracks[particle][track][:-1,:]\n",
    "            \n",
    "            true_pred =  np.zeros((1,7,3))\n",
    "            true_pred[0,:,:] = particle_tracks[particle][track][1:,:]\n",
    "            loss = rnn.sess.run(rnn.cost, feed_dict={rnn.X:inputt, rnn.Y:true_pred})\n",
    "            if loss < min_loss:\n",
    "                min_loss = loss\n",
    "            part_loss_list[track] += loss\n",
    "        \n",
    "        #print(min_loss)\n",
    "        minIndex = np.where(part_loss_list == min_loss)[0]\n",
    "        truth[minIndex] += 1\n",
    "        generated_truth_list.append(truth)\n",
    "        loss_list.append(part_loss_list)\n",
    "        #print(minIndex)\n",
    "    \n",
    "    return generated_truth_list, loss_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "#generated_truth, loss_list = find_best_tracks(particle_tracks=particle_tracks)\n",
    "#print(generated_truth[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_accuracy(generated_truth, track_truth= track_truth):\n",
    "    \n",
    "    num_particles = len(track_truth)\n",
    "    correct_list = []\n",
    "\n",
    "    for particle in range(num_particles):\n",
    "        correct = True\n",
    "        num_tracks = len(particle_tracks[particle])\n",
    "        \n",
    "        for track in range(num_tracks):\n",
    "            if track_truth[particle][track] != generated_truth[particle][track]:\n",
    "                correct = False\n",
    "        \n",
    "        if correct:\n",
    "            correct_list.append(particle)\n",
    "        \n",
    "    accuracy = len(correct_list)/num_particles\n",
    "    \n",
    "    print(\"The right track was chosen:\", accuracy*100, \"% of the time\")\n",
    "    print(len(correct_list), \"particles correctly assigned to their path\")\n",
    "\n",
    "    return correct_list\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "#correct_list = check_accuracy(generated_truth)\n",
    "generated_truth = pkl.load( open(\"generated_truth_\" + folder[2:-10] +\".pkl\" , \"rb\" ) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4425\n"
     ]
    }
   ],
   "source": [
    "#Count tracks that have no 8track path\n",
    "\n",
    "num_particles = len(track_truth)\n",
    "\n",
    "counter = 0\n",
    "\n",
    "for particle in range(num_particles):\n",
    "    \n",
    "    if generated_truth[particle].all() == 0:\n",
    "        counter +=1\n",
    "        \n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pkl.dump( generated_truth, open(\"generated_truth_\" + folder[2:-10] +\".pkl\" , \"wb\" ) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_Truth_to_Gen_truth(track_truth, generated_truth, loss_list):\n",
    "    \n",
    "\n",
    "    for particle in range(15, 30):\n",
    "        print()\n",
    "        print(\"Particle: \", particle)\n",
    "    \n",
    "        num_tracks = len(particle_tracks[particle])\n",
    "        \n",
    "        for track in range(num_tracks):\n",
    "            print(\"Truth: \", track_truth[particle][track])\n",
    "            print(\"Gen_truth: \", generated_truth[particle][track])\n",
    "            print(\"Loss: \", loss_list[particle][track])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.71172575  0.69559685  0.27218047]\n",
      " [-0.64037859  0.73127269  0.26876878]\n",
      " [-0.43078365  0.90224016  0.24776119]\n",
      " [-0.38153891  0.92340426  0.25317404]]\n",
      "0.0\n"
     ]
    }
   ],
   "source": [
    "particle_start_array = np.zeros((num_particles,4,3))\n",
    "\n",
    "def create_track_exist_truth(particle_start_array, track_truth):\n",
    "    \n",
    "\n",
    "    for particle in range(num_particles):\n",
    "        particle_start_array[particle,:,:] += particle_tracks[particle][0][:4,:]\n",
    "\n",
    "    print(particle_start_array[11,:,:])\n",
    "\n",
    "    track_exist_truth = np.zeros((num_particles))\n",
    "\n",
    "    for particle in range(num_particles):\n",
    "        correct = False\n",
    "        num_tracks = len(track_truth[particle])\n",
    "    \n",
    "        for track in range(num_tracks):\n",
    "            if track_truth[particle][track] == 1:\n",
    "                correct = True\n",
    "    \n",
    "        if correct:\n",
    "            track_exist_truth[particle] += 1\n",
    "    \n",
    "    print(track_exist_truth[11])\n",
    "    \n",
    "    return track_exist_truth\n",
    "\n",
    "track_exist_truth = create_track_exist_truth(particle_start_array=particle_start_array, track_truth=track_truth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Input: a = 3d array, b = 1d array\n",
    "\n",
    "def unison_shuffled_copies2(a, b):\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p]\n",
    "\n",
    "def create_random_sets2(particle_start_array= particle_start_array, track_exist_truth= track_exist_truth, train_to_total_ratio= 0.9):\n",
    "    #shuffle the dataset\n",
    "    num_examples = particle_start_array.shape[0]\n",
    "    particle_start_array, track_exist_truth = unison_shuffled_copies2(particle_start_array, track_exist_truth)\n",
    "    \n",
    "    #evaluate siye of training and test set and initialize them\n",
    "    train_set_size = np.int(num_examples*train_to_total_ratio)\n",
    "    test_set_size = num_examples - train_set_size\n",
    "    \n",
    "    train_part_start = np.zeros((train_set_size, particle_start_array.shape[1], particle_start_array.shape[2]))\n",
    "    train_track_e_tr = np.zeros((train_set_size))\n",
    "    test_part_start = np.zeros((test_set_size, particle_start_array.shape[1], particle_start_array.shape[2]))\n",
    "    test_track_e_tr = np.zeros((test_set_size))\n",
    "   \n",
    "\n",
    "    #fill train and test sets\n",
    "    for i in range(num_examples):\n",
    "        if train_set_size > i:\n",
    "            train_part_start[i,:,:] += particle_start_array[i,:,:]\n",
    "            train_track_e_tr[i] +=  track_exist_truth[i]\n",
    "        else:\n",
    "            test_part_start[i - train_set_size,:,:]  += particle_start_array[i,:,:]\n",
    "            test_track_e_tr[i - train_set_size] += track_exist_truth[i]\n",
    "                \n",
    "    return train_part_start, train_track_e_tr, test_part_start, test_track_e_tr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.16609355  0.91964991 -0.27518795]\n",
      " [ 0.31077846  0.97898671 -0.26426422]\n",
      " [ 0.75010406  0.65960674 -0.20298509]\n",
      " [ 0.83411023  0.55392795 -0.18595964]] 1.0\n"
     ]
    }
   ],
   "source": [
    "X_train, Y_train, X_test, Y_test = create_random_sets2()\n",
    "print(X_test[1], Y_test[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "lstm_1 (LSTM)                (None, 4, 10)             560       \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch (None, 4, 10)             40        \n",
      "_________________________________________________________________\n",
      "lstm_2 (LSTM)                (None, 10)                840       \n",
      "_________________________________________________________________\n",
      "batch_normalization_2 (Batch (None, 10)                40        \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 100)               1100      \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 1)                 101       \n",
      "=================================================================\n",
      "Total params: 2,681\n",
      "Trainable params: 2,641\n",
      "Non-trainable params: 40\n",
      "_________________________________________________________________\n",
      "None\n",
      "Train on 15057 samples, validate on 1673 samples\n",
      "Epoch 1/500\n",
      "15057/15057 [==============================] - 13s 844us/step - loss: 0.5526 - acc: 0.7584 - val_loss: 0.5406 - val_acc: 0.7639\n",
      "Epoch 2/500\n",
      "15057/15057 [==============================] - 9s 619us/step - loss: 0.5449 - acc: 0.7638 - val_loss: 0.5408 - val_acc: 0.7639\n",
      "Epoch 3/500\n",
      "15057/15057 [==============================] - 9s 620us/step - loss: 0.5437 - acc: 0.7638 - val_loss: 0.5416 - val_acc: 0.7639\n",
      "Epoch 4/500\n",
      "15057/15057 [==============================] - 9s 582us/step - loss: 0.5435 - acc: 0.7638 - val_loss: 0.5403 - val_acc: 0.7639\n",
      "Epoch 5/500\n",
      "15057/15057 [==============================] - 10s 632us/step - loss: 0.5429 - acc: 0.7638 - val_loss: 0.5419 - val_acc: 0.7639\n",
      "Epoch 6/500\n",
      "15057/15057 [==============================] - 9s 608us/step - loss: 0.5425 - acc: 0.7638 - val_loss: 0.5392 - val_acc: 0.7639\n",
      "Epoch 7/500\n",
      "15057/15057 [==============================] - 9s 612us/step - loss: 0.5424 - acc: 0.7638 - val_loss: 0.5405 - val_acc: 0.7639\n",
      "Epoch 8/500\n",
      "15057/15057 [==============================] - 9s 615us/step - loss: 0.5416 - acc: 0.7638 - val_loss: 0.5409 - val_acc: 0.7639\n",
      "Epoch 9/500\n",
      "15057/15057 [==============================] - 9s 616us/step - loss: 0.5420 - acc: 0.7638 - val_loss: 0.5391 - val_acc: 0.7639\n",
      "Epoch 10/500\n",
      "15057/15057 [==============================] - 9s 596us/step - loss: 0.5414 - acc: 0.7638 - val_loss: 0.5413 - val_acc: 0.7639\n",
      "Epoch 11/500\n",
      "15057/15057 [==============================] - 10s 648us/step - loss: 0.5411 - acc: 0.7638 - val_loss: 0.5396 - val_acc: 0.7639\n",
      "Epoch 12/500\n",
      "15057/15057 [==============================] - 9s 607us/step - loss: 0.5413 - acc: 0.7638 - val_loss: 0.5407 - val_acc: 0.7639\n",
      "Epoch 13/500\n",
      "15057/15057 [==============================] - 9s 609us/step - loss: 0.5409 - acc: 0.7638 - val_loss: 0.5390 - val_acc: 0.7639\n",
      "Epoch 14/500\n",
      "15057/15057 [==============================] - 9s 617us/step - loss: 0.5405 - acc: 0.7638 - val_loss: 0.5402 - val_acc: 0.7639\n",
      "Epoch 15/500\n",
      "15057/15057 [==============================] - 10s 634us/step - loss: 0.5407 - acc: 0.7638 - val_loss: 0.5388 - val_acc: 0.7639\n",
      "Epoch 16/500\n",
      "15057/15057 [==============================] - 10s 635us/step - loss: 0.5401 - acc: 0.7638 - val_loss: 0.5377 - val_acc: 0.7639\n",
      "Epoch 17/500\n",
      "15057/15057 [==============================] - 9s 579us/step - loss: 0.5399 - acc: 0.7638 - val_loss: 0.5372 - val_acc: 0.7639\n",
      "Epoch 18/500\n",
      "15057/15057 [==============================] - 9s 601us/step - loss: 0.5402 - acc: 0.7638 - val_loss: 0.5386 - val_acc: 0.7639\n",
      "Epoch 19/500\n",
      "15057/15057 [==============================] - 10s 639us/step - loss: 0.5399 - acc: 0.7638 - val_loss: 0.5376 - val_acc: 0.7639\n",
      "Epoch 20/500\n",
      "15057/15057 [==============================] - 9s 604us/step - loss: 0.5398 - acc: 0.7638 - val_loss: 0.5376 - val_acc: 0.7639\n",
      "Epoch 21/500\n",
      "15057/15057 [==============================] - 9s 604us/step - loss: 0.5402 - acc: 0.7638 - val_loss: 0.5386 - val_acc: 0.7639\n",
      "Epoch 22/500\n",
      "15057/15057 [==============================] - 9s 628us/step - loss: 0.5396 - acc: 0.7638 - val_loss: 0.5410 - val_acc: 0.7639\n",
      "Epoch 23/500\n",
      "15057/15057 [==============================] - 9s 624us/step - loss: 0.5396 - acc: 0.7638 - val_loss: 0.5392 - val_acc: 0.7639\n",
      "Epoch 24/500\n",
      "15057/15057 [==============================] - 9s 625us/step - loss: 0.5395 - acc: 0.7638 - val_loss: 0.5382 - val_acc: 0.7639\n",
      "Epoch 25/500\n",
      "15057/15057 [==============================] - 9s 629us/step - loss: 0.5395 - acc: 0.7638 - val_loss: 0.5379 - val_acc: 0.7639\n",
      "Epoch 26/500\n",
      "15057/15057 [==============================] - 9s 628us/step - loss: 0.5391 - acc: 0.7638 - val_loss: 0.5369 - val_acc: 0.7639\n",
      "Epoch 27/500\n",
      "15057/15057 [==============================] - 9s 606us/step - loss: 0.5388 - acc: 0.7638 - val_loss: 0.5394 - val_acc: 0.7639\n",
      "Epoch 28/500\n",
      "15057/15057 [==============================] - 9s 600us/step - loss: 0.5389 - acc: 0.7638 - val_loss: 0.5406 - val_acc: 0.7639\n",
      "Epoch 29/500\n",
      "15057/15057 [==============================] - 9s 587us/step - loss: 0.5385 - acc: 0.7638 - val_loss: 0.5390 - val_acc: 0.7639\n",
      "Epoch 30/500\n",
      "15057/15057 [==============================] - 9s 627us/step - loss: 0.5389 - acc: 0.7638 - val_loss: 0.5394 - val_acc: 0.7639\n",
      "Epoch 31/500\n",
      "15057/15057 [==============================] - 9s 610us/step - loss: 0.5384 - acc: 0.7638 - val_loss: 0.5405 - val_acc: 0.7639\n",
      "Epoch 32/500\n",
      "15057/15057 [==============================] - 9s 625us/step - loss: 0.5385 - acc: 0.7638 - val_loss: 0.5384 - val_acc: 0.7639\n",
      "Epoch 33/500\n",
      "15057/15057 [==============================] - 9s 577us/step - loss: 0.5377 - acc: 0.7638 - val_loss: 0.5397 - val_acc: 0.7639\n",
      "Epoch 34/500\n",
      "15057/15057 [==============================] - 9s 618us/step - loss: 0.5382 - acc: 0.7638 - val_loss: 0.5413 - val_acc: 0.7639\n",
      "Epoch 35/500\n",
      "15057/15057 [==============================] - 9s 614us/step - loss: 0.5378 - acc: 0.7638 - val_loss: 0.5390 - val_acc: 0.7639\n",
      "Epoch 36/500\n",
      "15057/15057 [==============================] - 9s 610us/step - loss: 0.5377 - acc: 0.7638 - val_loss: 0.5380 - val_acc: 0.7639\n",
      "Epoch 37/500\n",
      "15057/15057 [==============================] - 10s 640us/step - loss: 0.5375 - acc: 0.7638 - val_loss: 0.5408 - val_acc: 0.7639\n",
      "Epoch 38/500\n",
      "15057/15057 [==============================] - 9s 588us/step - loss: 0.5376 - acc: 0.7638 - val_loss: 0.5401 - val_acc: 0.7639\n",
      "Epoch 39/500\n",
      "15057/15057 [==============================] - 10s 637us/step - loss: 0.5363 - acc: 0.7638 - val_loss: 0.5446 - val_acc: 0.7639\n",
      "Epoch 40/500\n",
      "15057/15057 [==============================] - 9s 618us/step - loss: 0.5367 - acc: 0.7638 - val_loss: 0.5400 - val_acc: 0.7639\n",
      "Epoch 41/500\n",
      "15057/15057 [==============================] - 11s 753us/step - loss: 0.5367 - acc: 0.7638 - val_loss: 0.5392 - val_acc: 0.7639\n",
      "Epoch 42/500\n",
      "15057/15057 [==============================] - 12s 783us/step - loss: 0.5356 - acc: 0.7638 - val_loss: 0.5452 - val_acc: 0.7639\n",
      "Epoch 43/500\n",
      "15057/15057 [==============================] - 11s 753us/step - loss: 0.5356 - acc: 0.7638 - val_loss: 0.5418 - val_acc: 0.7639\n",
      "Epoch 44/500\n",
      "15057/15057 [==============================] - 12s 767us/step - loss: 0.5353 - acc: 0.7638 - val_loss: 0.5392 - val_acc: 0.7639\n",
      "Epoch 45/500\n",
      "15057/15057 [==============================] - 11s 763us/step - loss: 0.5354 - acc: 0.7638 - val_loss: 0.5397 - val_acc: 0.7639\n",
      "Epoch 46/500\n",
      "15057/15057 [==============================] - 11s 741us/step - loss: 0.5346 - acc: 0.7638 - val_loss: 0.5435 - val_acc: 0.7639\n",
      "Epoch 47/500\n",
      "15057/15057 [==============================] - 11s 759us/step - loss: 0.5344 - acc: 0.7638 - val_loss: 0.5407 - val_acc: 0.7639\n",
      "Epoch 48/500\n",
      "15057/15057 [==============================] - 12s 773us/step - loss: 0.5343 - acc: 0.7637 - val_loss: 0.5413 - val_acc: 0.7639\n",
      "Epoch 49/500\n",
      "15057/15057 [==============================] - 11s 754us/step - loss: 0.5336 - acc: 0.7638 - val_loss: 0.5419 - val_acc: 0.7639\n",
      "Epoch 50/500\n",
      "15057/15057 [==============================] - 12s 782us/step - loss: 0.5337 - acc: 0.7638 - val_loss: 0.5349 - val_acc: 0.7639\n",
      "Epoch 51/500\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15057/15057 [==============================] - 12s 770us/step - loss: 0.5338 - acc: 0.7638 - val_loss: 0.5374 - val_acc: 0.7639\n",
      "Epoch 52/500\n",
      "15057/15057 [==============================] - 12s 773us/step - loss: 0.5330 - acc: 0.7638 - val_loss: 0.5371 - val_acc: 0.7639\n",
      "Epoch 53/500\n",
      "15057/15057 [==============================] - 12s 775us/step - loss: 0.5321 - acc: 0.7638 - val_loss: 0.5370 - val_acc: 0.7639\n",
      "Epoch 54/500\n",
      "15057/15057 [==============================] - 12s 765us/step - loss: 0.5321 - acc: 0.7638 - val_loss: 0.5296 - val_acc: 0.7639\n",
      "Epoch 55/500\n",
      "15057/15057 [==============================] - 11s 745us/step - loss: 0.5307 - acc: 0.7638 - val_loss: 0.5358 - val_acc: 0.7639\n",
      "Epoch 56/500\n",
      "15057/15057 [==============================] - 11s 698us/step - loss: 0.5296 - acc: 0.7638 - val_loss: 0.5308 - val_acc: 0.7639\n",
      "Epoch 57/500\n",
      "15057/15057 [==============================] - 12s 791us/step - loss: 0.5296 - acc: 0.7638 - val_loss: 0.5303 - val_acc: 0.7639\n",
      "Epoch 58/500\n",
      "15057/15057 [==============================] - 11s 745us/step - loss: 0.5281 - acc: 0.7638 - val_loss: 0.5307 - val_acc: 0.7639\n",
      "Epoch 59/500\n",
      "15057/15057 [==============================] - 11s 754us/step - loss: 0.5271 - acc: 0.7638 - val_loss: 0.5373 - val_acc: 0.7639\n",
      "Epoch 60/500\n",
      "15057/15057 [==============================] - 11s 754us/step - loss: 0.5269 - acc: 0.7638 - val_loss: 0.5294 - val_acc: 0.7639\n",
      "Epoch 61/500\n",
      "15057/15057 [==============================] - 11s 759us/step - loss: 0.5278 - acc: 0.7638 - val_loss: 0.5314 - val_acc: 0.7639\n",
      "Epoch 62/500\n",
      "15057/15057 [==============================] - 11s 731us/step - loss: 0.5255 - acc: 0.7638 - val_loss: 0.5341 - val_acc: 0.7639\n",
      "Epoch 63/500\n",
      "15057/15057 [==============================] - 11s 711us/step - loss: 0.5253 - acc: 0.7638 - val_loss: 0.5335 - val_acc: 0.7639\n",
      "Epoch 64/500\n",
      "15057/15057 [==============================] - 11s 757us/step - loss: 0.5256 - acc: 0.7638 - val_loss: 0.5340 - val_acc: 0.7639\n",
      "Epoch 65/500\n",
      "15057/15057 [==============================] - 11s 715us/step - loss: 0.5249 - acc: 0.7638 - val_loss: 0.5279 - val_acc: 0.7639\n",
      "Epoch 66/500\n",
      "15057/15057 [==============================] - 10s 667us/step - loss: 0.5244 - acc: 0.7638 - val_loss: 0.5304 - val_acc: 0.7639\n",
      "Epoch 67/500\n",
      "15057/15057 [==============================] - 11s 699us/step - loss: 0.5242 - acc: 0.7638 - val_loss: 0.5316 - val_acc: 0.7639\n",
      "Epoch 68/500\n",
      "15057/15057 [==============================] - 11s 749us/step - loss: 0.5224 - acc: 0.7638 - val_loss: 0.5336 - val_acc: 0.7639\n",
      "Epoch 69/500\n",
      "15057/15057 [==============================] - 11s 729us/step - loss: 0.5241 - acc: 0.7638 - val_loss: 0.5362 - val_acc: 0.7639\n",
      "Epoch 70/500\n",
      "15057/15057 [==============================] - 11s 733us/step - loss: 0.5220 - acc: 0.7638 - val_loss: 0.5322 - val_acc: 0.7639\n",
      "Epoch 71/500\n",
      "15057/15057 [==============================] - 11s 761us/step - loss: 0.5217 - acc: 0.7638 - val_loss: 0.5278 - val_acc: 0.7639\n",
      "Epoch 72/500\n",
      "15057/15057 [==============================] - 11s 725us/step - loss: 0.5227 - acc: 0.7638 - val_loss: 0.5290 - val_acc: 0.7639\n",
      "Epoch 73/500\n",
      "15057/15057 [==============================] - 11s 750us/step - loss: 0.5232 - acc: 0.7638 - val_loss: 0.5322 - val_acc: 0.7639\n",
      "Epoch 74/500\n",
      "15057/15057 [==============================] - 11s 727us/step - loss: 0.5223 - acc: 0.7638 - val_loss: 0.5343 - val_acc: 0.7639\n",
      "Epoch 75/500\n",
      "15057/15057 [==============================] - 11s 723us/step - loss: 0.5210 - acc: 0.7638 - val_loss: 0.5329 - val_acc: 0.7639\n",
      "Epoch 76/500\n",
      "15057/15057 [==============================] - 11s 744us/step - loss: 0.5209 - acc: 0.7638 - val_loss: 0.5327 - val_acc: 0.7639\n",
      "Epoch 77/500\n",
      "15057/15057 [==============================] - 11s 715us/step - loss: 0.5215 - acc: 0.7638 - val_loss: 0.5358 - val_acc: 0.7639\n",
      "Epoch 78/500\n",
      "15057/15057 [==============================] - 10s 690us/step - loss: 0.5214 - acc: 0.7638 - val_loss: 0.5292 - val_acc: 0.7639\n",
      "Epoch 79/500\n",
      "15057/15057 [==============================] - 10s 690us/step - loss: 0.5197 - acc: 0.7638 - val_loss: 0.5287 - val_acc: 0.7639\n",
      "Epoch 80/500\n",
      "15057/15057 [==============================] - 11s 740us/step - loss: 0.5208 - acc: 0.7638 - val_loss: 0.5315 - val_acc: 0.7639\n",
      "Epoch 81/500\n",
      "15057/15057 [==============================] - 11s 720us/step - loss: 0.5204 - acc: 0.7638 - val_loss: 0.5313 - val_acc: 0.7639\n",
      "Epoch 82/500\n",
      "15057/15057 [==============================] - 11s 706us/step - loss: 0.5192 - acc: 0.7638 - val_loss: 0.5277 - val_acc: 0.7639\n",
      "Epoch 83/500\n",
      "15057/15057 [==============================] - 10s 678us/step - loss: 0.5189 - acc: 0.7638 - val_loss: 0.5321 - val_acc: 0.7639\n",
      "Epoch 84/500\n",
      "15057/15057 [==============================] - 10s 681us/step - loss: 0.5191 - acc: 0.7638 - val_loss: 0.5258 - val_acc: 0.7639\n",
      "Epoch 85/500\n",
      "15057/15057 [==============================] - 10s 684us/step - loss: 0.5202 - acc: 0.7638 - val_loss: 0.5232 - val_acc: 0.7639\n",
      "Epoch 86/500\n",
      "15057/15057 [==============================] - 11s 716us/step - loss: 0.5195 - acc: 0.7638 - val_loss: 0.5231 - val_acc: 0.7639\n",
      "Epoch 87/500\n",
      "15057/15057 [==============================] - 11s 731us/step - loss: 0.5194 - acc: 0.7638 - val_loss: 0.5297 - val_acc: 0.7639\n",
      "Epoch 88/500\n",
      "15057/15057 [==============================] - 11s 742us/step - loss: 0.5188 - acc: 0.7638 - val_loss: 0.5254 - val_acc: 0.7639\n",
      "Epoch 89/500\n",
      "15057/15057 [==============================] - 11s 720us/step - loss: 0.5197 - acc: 0.7638 - val_loss: 0.5249 - val_acc: 0.7639\n",
      "Epoch 90/500\n",
      "15057/15057 [==============================] - 12s 771us/step - loss: 0.5173 - acc: 0.7638 - val_loss: 0.5258 - val_acc: 0.7639\n",
      "Epoch 91/500\n",
      "15057/15057 [==============================] - 11s 742us/step - loss: 0.5178 - acc: 0.7637 - val_loss: 0.5333 - val_acc: 0.7639\n",
      "Epoch 92/500\n",
      "15057/15057 [==============================] - 11s 722us/step - loss: 0.5191 - acc: 0.7637 - val_loss: 0.5324 - val_acc: 0.7639\n",
      "Epoch 93/500\n",
      "15057/15057 [==============================] - 11s 756us/step - loss: 0.5173 - acc: 0.7638 - val_loss: 0.5207 - val_acc: 0.7639\n",
      "Epoch 94/500\n",
      "15057/15057 [==============================] - 11s 744us/step - loss: 0.5188 - acc: 0.7637 - val_loss: 0.5307 - val_acc: 0.7639\n",
      "Epoch 95/500\n",
      "15057/15057 [==============================] - 12s 775us/step - loss: 0.5174 - acc: 0.7640 - val_loss: 0.5255 - val_acc: 0.7639\n",
      "Epoch 96/500\n",
      "15057/15057 [==============================] - 11s 759us/step - loss: 0.5171 - acc: 0.7638 - val_loss: 0.5247 - val_acc: 0.7639\n",
      "Epoch 97/500\n",
      "15057/15057 [==============================] - 12s 788us/step - loss: 0.5163 - acc: 0.7636 - val_loss: 0.5239 - val_acc: 0.7639\n",
      "Epoch 98/500\n",
      "15057/15057 [==============================] - 11s 740us/step - loss: 0.5179 - acc: 0.7638 - val_loss: 0.5213 - val_acc: 0.7639\n",
      "Epoch 99/500\n",
      "15057/15057 [==============================] - 11s 751us/step - loss: 0.5164 - acc: 0.7639 - val_loss: 0.5270 - val_acc: 0.7633\n",
      "Epoch 100/500\n",
      "15057/15057 [==============================] - 11s 748us/step - loss: 0.5174 - acc: 0.7639 - val_loss: 0.5255 - val_acc: 0.7639\n",
      "Epoch 101/500\n",
      "15057/15057 [==============================] - 11s 717us/step - loss: 0.5176 - acc: 0.7639 - val_loss: 0.5296 - val_acc: 0.7639\n",
      "Epoch 102/500\n",
      "15057/15057 [==============================] - 11s 709us/step - loss: 0.5167 - acc: 0.7638 - val_loss: 0.5216 - val_acc: 0.7639\n",
      "Epoch 103/500\n",
      "15057/15057 [==============================] - 10s 654us/step - loss: 0.5167 - acc: 0.7638 - val_loss: 0.5215 - val_acc: 0.7639\n",
      "Epoch 104/500\n",
      "15057/15057 [==============================] - 10s 652us/step - loss: 0.5158 - acc: 0.7638 - val_loss: 0.5316 - val_acc: 0.7639\n",
      "Epoch 105/500\n",
      "15057/15057 [==============================] - 10s 682us/step - loss: 0.5173 - acc: 0.7638 - val_loss: 0.5247 - val_acc: 0.7639\n",
      "Epoch 106/500\n",
      "15057/15057 [==============================] - 11s 717us/step - loss: 0.5162 - acc: 0.7640 - val_loss: 0.5270 - val_acc: 0.7639\n",
      "Epoch 107/500\n",
      "15057/15057 [==============================] - 10s 680us/step - loss: 0.5181 - acc: 0.7638 - val_loss: 0.5245 - val_acc: 0.7639\n",
      "Epoch 108/500\n",
      "15057/15057 [==============================] - 10s 672us/step - loss: 0.5165 - acc: 0.7638 - val_loss: 0.5243 - val_acc: 0.7639\n",
      "Epoch 109/500\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15057/15057 [==============================] - 10s 693us/step - loss: 0.5165 - acc: 0.7638 - val_loss: 0.5233 - val_acc: 0.7639\n",
      "Epoch 110/500\n",
      "15057/15057 [==============================] - 11s 730us/step - loss: 0.5146 - acc: 0.7639 - val_loss: 0.5233 - val_acc: 0.7639\n",
      "Epoch 111/500\n",
      "15057/15057 [==============================] - 11s 745us/step - loss: 0.5175 - acc: 0.7638 - val_loss: 0.5218 - val_acc: 0.7639\n",
      "Epoch 112/500\n",
      "15057/15057 [==============================] - 11s 742us/step - loss: 0.5159 - acc: 0.7640 - val_loss: 0.5273 - val_acc: 0.7639\n",
      "Epoch 113/500\n",
      "15057/15057 [==============================] - 11s 729us/step - loss: 0.5159 - acc: 0.7635 - val_loss: 0.5263 - val_acc: 0.7639\n",
      "Epoch 114/500\n",
      "15057/15057 [==============================] - 11s 718us/step - loss: 0.5151 - acc: 0.7638 - val_loss: 0.5267 - val_acc: 0.7639\n",
      "Epoch 115/500\n",
      "15057/15057 [==============================] - 10s 643us/step - loss: 0.5147 - acc: 0.7637 - val_loss: 0.5234 - val_acc: 0.7639\n",
      "Epoch 116/500\n",
      "15057/15057 [==============================] - 10s 658us/step - loss: 0.5153 - acc: 0.7636 - val_loss: 0.5237 - val_acc: 0.7639\n",
      "Epoch 117/500\n",
      "15057/15057 [==============================] - 10s 642us/step - loss: 0.5152 - acc: 0.7639 - val_loss: 0.5268 - val_acc: 0.7639\n",
      "Epoch 118/500\n",
      "15057/15057 [==============================] - 10s 688us/step - loss: 0.5162 - acc: 0.7638 - val_loss: 0.5269 - val_acc: 0.7639\n",
      "Epoch 119/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5155 - acc: 0.7640 - val_loss: 0.5214 - val_acc: 0.7639\n",
      "Epoch 120/500\n",
      "15057/15057 [==============================] - 17s 1ms/step - loss: 0.5150 - acc: 0.7640 - val_loss: 0.5276 - val_acc: 0.7639\n",
      "Epoch 121/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5145 - acc: 0.7639 - val_loss: 0.5233 - val_acc: 0.7633\n",
      "Epoch 122/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5155 - acc: 0.7637 - val_loss: 0.5242 - val_acc: 0.7639\n",
      "Epoch 123/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5150 - acc: 0.7636 - val_loss: 0.5187 - val_acc: 0.7639\n",
      "Epoch 124/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5154 - acc: 0.7638 - val_loss: 0.5243 - val_acc: 0.7639\n",
      "Epoch 125/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5142 - acc: 0.7639 - val_loss: 0.5222 - val_acc: 0.7633\n",
      "Epoch 126/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5150 - acc: 0.7638 - val_loss: 0.5274 - val_acc: 0.7633\n",
      "Epoch 127/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5147 - acc: 0.7640 - val_loss: 0.5239 - val_acc: 0.7633\n",
      "Epoch 128/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5130 - acc: 0.7640 - val_loss: 0.5318 - val_acc: 0.7639\n",
      "Epoch 129/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5137 - acc: 0.7640 - val_loss: 0.5235 - val_acc: 0.7639\n",
      "Epoch 130/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5146 - acc: 0.7637 - val_loss: 0.5338 - val_acc: 0.7639\n",
      "Epoch 131/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5136 - acc: 0.7637 - val_loss: 0.5272 - val_acc: 0.7633\n",
      "Epoch 132/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5150 - acc: 0.7639 - val_loss: 0.5219 - val_acc: 0.7639\n",
      "Epoch 133/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5138 - acc: 0.7638 - val_loss: 0.5301 - val_acc: 0.7639\n",
      "Epoch 134/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5150 - acc: 0.7638 - val_loss: 0.5260 - val_acc: 0.7639\n",
      "Epoch 135/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5139 - acc: 0.7637 - val_loss: 0.5257 - val_acc: 0.7639\n",
      "Epoch 136/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5139 - acc: 0.7638 - val_loss: 0.5249 - val_acc: 0.7633\n",
      "Epoch 137/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5135 - acc: 0.7640 - val_loss: 0.5251 - val_acc: 0.7645\n",
      "Epoch 138/500\n",
      "15057/15057 [==============================] - 14s 914us/step - loss: 0.5134 - acc: 0.7638 - val_loss: 0.5211 - val_acc: 0.7639\n",
      "Epoch 139/500\n",
      "15057/15057 [==============================] - 15s 980us/step - loss: 0.5136 - acc: 0.7636 - val_loss: 0.5215 - val_acc: 0.7639\n",
      "Epoch 140/500\n",
      "15057/15057 [==============================] - 15s 975us/step - loss: 0.5139 - acc: 0.7632 - val_loss: 0.5236 - val_acc: 0.7639\n",
      "Epoch 141/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5138 - acc: 0.7634 - val_loss: 0.5243 - val_acc: 0.7633\n",
      "Epoch 142/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5120 - acc: 0.7641 - val_loss: 0.5228 - val_acc: 0.7639\n",
      "Epoch 143/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5132 - acc: 0.7640 - val_loss: 0.5271 - val_acc: 0.7639\n",
      "Epoch 144/500\n",
      "15057/15057 [==============================] - 15s 983us/step - loss: 0.5130 - acc: 0.7642 - val_loss: 0.5274 - val_acc: 0.7639\n",
      "Epoch 145/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5127 - acc: 0.7637 - val_loss: 0.5299 - val_acc: 0.7639\n",
      "Epoch 146/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5137 - acc: 0.7640 - val_loss: 0.5331 - val_acc: 0.7639\n",
      "Epoch 147/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5130 - acc: 0.7643 - val_loss: 0.5258 - val_acc: 0.7615\n",
      "Epoch 148/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5130 - acc: 0.7636 - val_loss: 0.5232 - val_acc: 0.7639\n",
      "Epoch 149/500\n",
      "15057/15057 [==============================] - 15s 1ms/step - loss: 0.5112 - acc: 0.7639 - val_loss: 0.5216 - val_acc: 0.7639\n",
      "Epoch 150/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5131 - acc: 0.7639 - val_loss: 0.5212 - val_acc: 0.7645\n",
      "Epoch 151/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5127 - acc: 0.7636 - val_loss: 0.5280 - val_acc: 0.7627\n",
      "Epoch 152/500\n",
      "15057/15057 [==============================] - 16s 1ms/step - loss: 0.5116 - acc: 0.7640 - val_loss: 0.5227 - val_acc: 0.7639\n",
      "Epoch 153/500\n",
      "15057/15057 [==============================] - 13s 869us/step - loss: 0.5118 - acc: 0.7634 - val_loss: 0.5300 - val_acc: 0.7633\n",
      "Accuracy: 76.39%\n"
     ]
    }
   ],
   "source": [
    "# truncate and pad input sequences\n",
    "max_review_length = 4\n",
    "filepath = \"./keras_model_classifier_LSTM_40_LSTM_40.h5\"\n",
    "\n",
    "callbacks = [\n",
    "    EarlyStopping(monitor='val_loss', patience=30, min_delta=0),\n",
    "    ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True),\n",
    "    History()\n",
    "]\n",
    "\n",
    "#\n",
    "\n",
    "# create the model\n",
    "model = Sequential()\n",
    "#model.add(Dense(12, input_shape=(4,3)))\n",
    "model.add(LSTM(10, return_sequences=True, input_shape=(4,3), activation = 'relu'))\n",
    "model.add(BatchNormalization())\n",
    "model.add(LSTM(10, return_sequences=False, activation = 'relu'))    \n",
    "model.add(BatchNormalization())\n",
    "#model.add(LSTM(40, return_sequences=True, activation = 'relu'))    \n",
    "#model.add(Dropout(0.5))\n",
    "#model.add(LSTM(4, activation = 'relu'))    \n",
    "#model.add(BatchNormalization())\n",
    "model.add(Dense(100, activation='relu'))\n",
    "model.add(Dense(1, activation='sigmoid'))\n",
    "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "print(model.summary())\n",
    "model.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=500, batch_size=50, callbacks= callbacks)\n",
    "model = load_model(filepath)\n",
    "# Final evaluation of the model\n",
    "scores = model.evaluate(X_test, Y_test, verbose=0)\n",
    "print(\"Accuracy: %.2f%%\" % (scores[1]*100))"
   ]
  },
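  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The History() object in the callbacks list keeps the per-epoch metrics, so the learning curves can still be inspected after the best checkpoint was reloaded. A minimal sketch (assuming the training cell above was run in this session):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot training vs. validation loss recorded by the History callback (sketch)\n",
    "hist = callbacks[2].history\n",
    "plt.plot(hist['loss'], label='loss')\n",
    "plt.plot(hist['val_loss'], label='val_loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('binary crossentropy')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configurable RNN track classifier\n",
    "\n",
    "#### keras_track_classifier()\n",
    "\n",
    "Description:\n",
    "\n",
    "Builds, trains and evaluates a stacked RNN classifier with a configurable number of LSTM/GRU layers\n",
    "\n",
    "Arguments:\n",
    "\n",
    "- X_train, Y_train: Training tracks and labels\n",
    "- X_test, Y_test: Test tracks and labels (used for validation and the final evaluation)\n",
    "- kncells: List with the number of cells per recurrent layer\n",
    "- kcelltype: List with the cell type per layer ('LSTM' or 'GRU')\n",
    "- activation: Activation function, 'tanh' or 'relu' (default value = 'tanh')\n",
    "- input_shape: Shape of one input sequence (default value = (4,3))\n",
    "- dropout_rate: Dropout rate after each recurrent layer, 0 disables dropout (default value = 0.5)\n",
    "- epochs, batch_size: Training parameters (default values = 500, 50)\n",
    "\n",
    "Returns:\n",
    "\n",
    "- model: Trained model with the weights of the best epoch restored"
   ]
  },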
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def keras_track_classifier(X_train, Y_train, X_test, Y_test, kncells, kcelltype, activation= 'tanh', input_shape=(4,3)\n",
    "                           ,dropout_rate = 0.5, epochs= 500, batch_size = 50):\n",
    "    \n",
    "    \n",
    "    filepath = \"keras_classifier_\" + str(kncells) + str(kcelltype)\n",
    "    \n",
    "    callbacks = [\n",
    "    EarlyStopping(monitor='val_loss', patience=30, min_delta=0),\n",
    "    ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True),\n",
    "    History()\n",
    "    ]\n",
    "    \n",
    "    model = Sequential()\n",
    "    \n",
    "    if activation != 'relu' and activation != 'tanh':\n",
    "        raise ValueError(\"Uknown activation function: {}\".format(activation))\n",
    "    \n",
    "    for layer in range(len(kncells)):\n",
    "        cells = kncells[layer]\n",
    "        \n",
    "        return_seq = False\n",
    "        \n",
    "        if layer < len(kncells):\n",
    "            return_seq = True\n",
    "        \n",
    "        \n",
    "        if layer == 0:\n",
    "            if kcelltype[layer] == \"LSTM\":\n",
    "                model.add(LSTM(kncells[layer], return_sequences=return_seq, input_shape=input_shape,\n",
    "                               activation = activation))\n",
    "            elif  kcelltype[layer] == \"GRU\":\n",
    "                model.add(GRU(kncells[layer], return_sequences=return_seq, input_shape=input_shape,\n",
    "                               activation = activation))\n",
    "            else:\n",
    "                raise ValueError(\"Uknown celltype: {}\".format(kcelltype[layer]))\n",
    "        else:\n",
    "            if kcelltype[layer] == \"LSTM\":\n",
    "                model.add(LSTM(kncells[layer], return_sequences=return_seq,\n",
    "                               activation = activation))\n",
    "            elif  kcelltype[layer] == \"GRU\":\n",
    "                model.add(GRU(kncells[layer], return_sequences=return_seq,\n",
    "                               activation = activation))\n",
    "            else:\n",
    "                raise ValueError(\"Uknown celltype: {}\".format(kcelltype[layer]))\n",
    "            \n",
    "        if dropout_rate != 0:\n",
    "            model.add(Dropout(dropout_rate))\n",
    "    \n",
    "    model.add(Dense(1, activation='sigmoid'))\n",
    "    \n",
    "    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "    \n",
    "    print(model.summary())\n",
    "    \n",
    "    model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, callbacks= callbacks)\n",
    "    \n",
    "    model = load_model(filepath)\n",
    "    \n",
    "    scores = model.evaluate(X_test, Y_test, verbose=0)\n",
    "    \n",
    "    print(\"Accuracy: %.2f%%\" % (scores[1]*100))\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "lstm_3 (LSTM)                (None, 4, 100)            41600     \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 4, 100)            0         \n",
      "_________________________________________________________________\n",
      "gru_1 (GRU)                  (None, 4, 50)             22650     \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 4, 50)             0         \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 4, 1)              51        \n",
      "=================================================================\n",
      "Total params: 64,301\n",
      "Trainable params: 64,301\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Error when checking target: expected dense_3 to have 3 dimensions, but got array with shape (15057, 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-47-2318c693bf5d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m model = keras_track_classifier(X_train= X_train, Y_train= Y_train, X_test= X_test, Y_test= Y_test,\n\u001b[1;32m----> 5\u001b[1;33m                                kncells= kncells, kcelltype= kcelltype)\n\u001b[0m",
      "\u001b[1;32m<ipython-input-46-65c1ea3596a5>\u001b[0m in \u001b[0;36mkeras_track_classifier\u001b[1;34m(X_train, Y_train, X_test, Y_test, kncells, kcelltype, activation, input_shape, dropout_rate, epochs, batch_size)\u001b[0m\n\u001b[0;32m     53\u001b[0m     \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msummary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     54\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 55\u001b[1;33m     \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mY_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[1;33m=\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     56\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     57\u001b[0m     \u001b[0mmodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\keras\\models.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[0;32m    961\u001b[0m                               \u001b[0minitial_epoch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    962\u001b[0m                               \u001b[0msteps_per_epoch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 963\u001b[1;33m                               validation_steps=validation_steps)\n\u001b[0m\u001b[0;32m    964\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    965\u001b[0m     def evaluate(self, x=None, y=None,\n",
      "\u001b[1;32mc:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[0;32m   1628\u001b[0m             \u001b[0msample_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1629\u001b[0m             \u001b[0mclass_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mclass_weight\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1630\u001b[1;33m             batch_size=batch_size)\n\u001b[0m\u001b[0;32m   1631\u001b[0m         \u001b[1;31m# Prepare validation data.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1632\u001b[0m         \u001b[0mdo_validation\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36m_standardize_user_data\u001b[1;34m(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)\u001b[0m\n\u001b[0;32m   1478\u001b[0m                                     \u001b[0moutput_shapes\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1479\u001b[0m                                     \u001b[0mcheck_batch_axis\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1480\u001b[1;33m                                     exception_prefix='target')\n\u001b[0m\u001b[0;32m   1481\u001b[0m         sample_weights = _standardize_sample_weights(sample_weight,\n\u001b[0;32m   1482\u001b[0m                                                      self._feed_output_names)\n",
      "\u001b[1;32mc:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36m_standardize_input_data\u001b[1;34m(data, names, shapes, check_batch_axis, exception_prefix)\u001b[0m\n\u001b[0;32m    111\u001b[0m                         \u001b[1;34m': expected '\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mnames\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m' to have '\u001b[0m \u001b[1;33m+\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    112\u001b[0m                         \u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m' dimensions, but got array '\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 113\u001b[1;33m                         'with shape ' + str(data_shape))\n\u001b[0m\u001b[0;32m    114\u001b[0m                 \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mcheck_batch_axis\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    115\u001b[0m                     \u001b[0mdata_shape\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata_shape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mValueError\u001b[0m: Error when checking target: expected dense_3 to have 3 dimensions, but got array with shape (15057, 1)"
     ]
    }
   ],
   "source": [
    "kncells = [100, 50]\n",
    "kcelltype = ['LSTM', 'GRU']\n",
    "\n",
    "model = keras_track_classifier(X_train= X_train, Y_train= Y_train, X_test= X_test, Y_test= Y_test,\n",
    "                               kncells= kncells, kcelltype= kcelltype)"
   ]
  },
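  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sigmoid output of the classifier can be read as a matching probability per track. A minimal sketch (assuming X_test still holds the 3D test sequences):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted matching probabilities for the first few test tracks (sketch)\n",
    "probs = model.predict(X_test[:5])\n",
    "print(probs[:, 0])"
   ]
  },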
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pkl.dump( scalor, open(\"scalor.pkl\" , \"wb\" ) )"
   ]
  },
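  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pickled scaler can be restored in a later session to scale new tracks exactly as during training. A minimal sketch (the name scalor_loaded is only illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the scaler saved above (sketch)\n",
    "with open(\"scalor.pkl\", \"rb\") as f:\n",
    "    scalor_loaded = pkl.load(f)"
   ]
  },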
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filepath = \"./keras_model_classifier_LSTM_40_LSTM_40.h5\"\n",
    "\n",
    "model.save(filepath)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = load_model(filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = model.evaluate(X_test, Y_test, verbose=0)\n",
    "print(\"Accuracy: %.2f%%\" % (scores[1]*100))"
   ]
  },
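  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison with XGBoost\n",
    "\n",
    "XGBoost expects a 2D input of shape (samples, features), so the 3D track sequences are flattened back with reshapor_inv() before the gradient-boosted trees are trained."
   ]
  },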
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = reshapor_inv(X_train)\n",
    "\n",
    "X_test = reshapor_inv(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit model no training data\n",
    "model = XGBClassifier(max_depth=5, n_estimators=1000, learning_rate=0.05).fit(X_train, Y_train, verbose = 0)\n",
    "\n",
    "predictions = model.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = 0\n",
    "\n",
    "for prediction in predictions:\n",
    "    if prediction == 1:\n",
    "        c += 1\n",
    "\n",
    "print(c)"
   ]
  },
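  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To judge this count, it can be compared with the number of tracks that are labelled as true matches in the test set. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of true matches among the test labels (sketch)\n",
    "print(int(np.sum(Y_test)))"
   ]
  },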
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = [round(value) for value in predictions]\n",
    "\n",
    "# evaluate predictions\n",
    "accuracy = accuracy_score(Y_test, predictions)\n",
    "print(\"Accuracy: %.2f%%\" % (accuracy * 100.0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}