Newer
Older
rnn_bachelor_thesis / 1_to_1_multi_layer.ipynb
@saslie saslie on 29 Apr 2018 28 KB Test on second laptop
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "from sklearn import preprocessing\n",
    "import pickle as pkl\n",
    "from pathlib import Path\n",
    "\n",
    "#import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the matched tracks from a pickle file in the working directory.\n",
    "# Per the original note, each row holds 8 hits with x, y, z coordinates.\n",
    "# NOTE(review): unpickling can execute arbitrary code - only load trusted files.\n",
    "\n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')\n",
    "#print(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the loaded table to a float32 NumPy array for the later steps.\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "# Optional sanity checks on a single particle (disabled):\n",
    "#print(tset.shape)\n",
    "#for i in range(8):\n",
    "    #print(tset[1,3*i:(3*i+3)])\n",
    "#print(tset[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    timesteps = int(arr_orig.shape[1]/3)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    arr = np.zeros((number_examples, timesteps, 3))\n",
    "    \n",
    "    for i in range(number_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,t,0:3] = arr_orig[i,3*t:3*t+3]\n",
    "        \n",
    "    return arr\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    timesteps = int(array_shaped.shape[1])\n",
    "    num_examples = int(array_shaped.shape[0])\n",
    "    arr = np.zeros((num_examples, timesteps*3))\n",
    "    \n",
    "    for i in range(num_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,3*t:3*t+3] = array_shaped[i,t,:]\n",
    "        \n",
    "    return arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "    \n",
    "    #evaluate siye of training and test set and initialize them\n",
    "    train_set_size = np.int(num_examples*train_to_total_ratio)\n",
    "    test_set_size = num_examples - train_set_size\n",
    "    \n",
    "    train_set = np.zeros((train_set_size, dataset.shape[1]))\n",
    "    test_set = np.zeros((test_set_size, dataset.shape[1]))\n",
    "   \n",
    "\n",
    "    #fill train and test sets\n",
    "    for i in range(num_examples):\n",
    "        if train_set_size > i:\n",
    "            train_set[i,:] += dataset[i,:]\n",
    "        else:\n",
    "            test_set[i - train_set_size,:]  += dataset[i,:]\n",
    "                \n",
    "    return train_set, test_set\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "\n",
    "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "#print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn\n",
    "\n",
    "#set the transormation based on training set\n",
    "def set_min_max_scaler(arr, feature_range= (-1,1)):\n",
    "    \"\"\"Fit a MinMaxScaler on arr and return the fitted scaler.\n",
    "\n",
    "    Args:\n",
    "        arr: training data; a 3-D array is flattened with reshapor_inv before\n",
    "            fitting, any other array is fitted as given.\n",
    "        feature_range: target range handed to MinMaxScaler.\n",
    "\n",
    "    Returns:\n",
    "        The fitted scaler. The data itself is not transformed here.\n",
    "    \"\"\"\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    flat = reshapor_inv(arr) if len(arr.shape) == 3 else arr\n",
    "    # Only the fitted state is needed, so fit() replaces fit_transform() and\n",
    "    # no scaled copy is built and thrown away.\n",
    "    min_max_scalor.fit(flat)\n",
    "    return min_max_scalor\n",
    "\n",
    "# Fit on train_set as it is at this point of the notebook.\n",
    "# NOTE(review): train_set is reassigned in later cells, so re-running this\n",
    "# cell afterwards would fit on different data - confirm the intended order.\n",
    "min_max_scalor = set_min_max_scaler(train_set)\n",
    "\n",
    "\n",
    "#transform data\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "    \"\"\"Apply the fitted min-max scaler to arr.\n",
    "\n",
    "    3-D input (examples, timesteps, 3) is flattened with reshapor_inv, scaled\n",
    "    and reshaped back with reshapor; other input is scaled as a flat array.\n",
    "    Input whose width is not 8 time steps / 24 columns is copied into the\n",
    "    leading columns of a zero (examples, 24) array for the transform and cut\n",
    "    back to its own width afterwards.\n",
    "    \"\"\"\n",
    "    is_3d = len(arr.shape) == 3\n",
    "    full_width = (arr.shape[1] == 8) if is_3d else (arr.shape[1] == 24)\n",
    "    flat = reshapor_inv(arr) if is_3d else arr\n",
    "\n",
    "    if full_width:\n",
    "        scaled = min_max_scalor.transform(flat)\n",
    "    else:\n",
    "        width = flat.shape[1]\n",
    "        padded = np.zeros((flat.shape[0], 24))\n",
    "        padded[:, :width] += flat\n",
    "        scaled = min_max_scalor.transform(padded)[:, :width]\n",
    "\n",
    "    return reshapor(scaled) if is_3d else scaled\n",
    "        \n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "    \"\"\"Undo the min-max scaling of arr.\n",
    "\n",
    "    3-D input (examples, timesteps, 3) is flattened with reshapor_inv,\n",
    "    inverse-transformed and reshaped back with reshapor; other input is\n",
    "    handled as a flat array. Input whose width is not 8 time steps / 24\n",
    "    columns is copied into the leading columns of a zero (examples, 24) array\n",
    "    for the inverse transform and cut back to its own width afterwards.\n",
    "\n",
    "    Args:\n",
    "        arr: scaled data, 3-D or flat.\n",
    "        min_max_scalor: fitted scaler providing inverse_transform.\n",
    "\n",
    "    Returns:\n",
    "        Array with the same layout as arr in the original value range.\n",
    "    \"\"\"\n",
    "    is_3d = len(arr.shape) == 3\n",
    "    full_width = (arr.shape[1] == 8) if is_3d else (arr.shape[1] == 24)\n",
    "    flat = reshapor_inv(arr) if is_3d else arr\n",
    "\n",
    "    if full_width:\n",
    "        restored = min_max_scalor.inverse_transform(flat)\n",
    "    else:\n",
    "        width = flat.shape[1]\n",
    "        padded = np.zeros((flat.shape[0], 24))\n",
    "        padded[:, :width] += flat\n",
    "        # the flat, partial-width path previously called a misspelled method\n",
    "        # name and failed with AttributeError\n",
    "        restored = min_max_scalor.inverse_transform(padded)[:, :width]\n",
    "\n",
    "    return reshapor(restored) if is_3d else restored"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn - Standard scaler\n",
    "\n",
    "#set the transormation based on training set\n",
    "def set_std_scaler(arr):\n",
    "    \"\"\"Fit a StandardScaler on arr and return the fitted scaler.\n",
    "\n",
    "    Args:\n",
    "        arr: training data; a 3-D array is flattened with reshapor_inv before\n",
    "            fitting, any other array is fitted as given.\n",
    "\n",
    "    Returns:\n",
    "        The fitted scaler. The data itself is not transformed here.\n",
    "    \"\"\"\n",
    "    std_scalor = preprocessing.StandardScaler()\n",
    "    flat = reshapor_inv(arr) if len(arr.shape) == 3 else arr\n",
    "    # fit() returns the scaler itself, not an array, so its result must not be\n",
    "    # handed to reshapor (that made the 3-D branch fail).\n",
    "    std_scalor.fit(flat)\n",
    "    return std_scalor\n",
    "\n",
    "# Fit on train_set as it is at this point of the notebook.\n",
    "# NOTE(review): train_set is reassigned in later cells, so re-running this\n",
    "# cell afterwards would fit on different data - confirm the intended order.\n",
    "std_scalor = set_std_scaler(train_set)\n",
    "\n",
    "#transform data\n",
    "def std_scaler(arr, std_scalor= std_scalor):\n",
    "    \"\"\"Apply the fitted standard scaler to arr.\n",
    "\n",
    "    3-D input (examples, timesteps, 3) is flattened with reshapor_inv, scaled\n",
    "    and reshaped back with reshapor; other input is scaled as a flat array.\n",
    "    Input whose width is not 8 time steps / 24 columns is copied into the\n",
    "    leading columns of a zero (examples, 24) array for the transform and cut\n",
    "    back to its own width afterwards.\n",
    "    \"\"\"\n",
    "    is_3d = len(arr.shape) == 3\n",
    "    full_width = (arr.shape[1] == 8) if is_3d else (arr.shape[1] == 24)\n",
    "    flat = reshapor_inv(arr) if is_3d else arr\n",
    "\n",
    "    if full_width:\n",
    "        scaled = std_scalor.transform(flat)\n",
    "    else:\n",
    "        width = flat.shape[1]\n",
    "        padded = np.zeros((flat.shape[0], 24))\n",
    "        padded[:, :width] += flat\n",
    "        scaled = std_scalor.transform(padded)[:, :width]\n",
    "\n",
    "    return reshapor(scaled) if is_3d else scaled\n",
    "        \n",
    "#inverse transformation\n",
    "def std_scaler_inv(arr, std_scalor= std_scalor):\n",
    "    \"\"\"Undo the standard scaling of arr.\n",
    "\n",
    "    3-D input (examples, timesteps, 3) is flattened with reshapor_inv,\n",
    "    inverse-transformed and reshaped back with reshapor; other input is\n",
    "    handled as a flat array. Input whose width is not 8 time steps / 24\n",
    "    columns is copied into the leading columns of a zero (examples, 24) array\n",
    "    for the inverse transform and cut back to its own width afterwards.\n",
    "    \"\"\"\n",
    "    is_3d = len(arr.shape) == 3\n",
    "    full_width = (arr.shape[1] == 8) if is_3d else (arr.shape[1] == 24)\n",
    "    flat = reshapor_inv(arr) if is_3d else arr\n",
    "\n",
    "    if full_width:\n",
    "        restored = std_scalor.inverse_transform(flat)\n",
    "    else:\n",
    "        width = flat.shape[1]\n",
    "        padded = np.zeros((flat.shape[0], 24))\n",
    "        padded[:, :width] += flat\n",
    "        restored = std_scalor.inverse_transform(padded)[:, :width]\n",
    "\n",
    "    return reshapor(restored) if is_3d else restored\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshape both sets from flat rows to (examples, timesteps, 3).\n",
    "# NOTE(review): the names are overwritten in place, so this cell is not safe\n",
    "# to run twice - reshapor expects the flat layout as input.\n",
    "\n",
    "train_set = reshapor(train_set)\n",
    "test_set = reshapor(test_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Scale data either with MinMax scaler or with Standard scaler\n",
    "#Return scalor if fit = True and and scaled array otherwise\n",
    "\n",
    "def scaler(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \"\"\"Scale arr with the scaler selected by scalerfunc.\n",
    "\n",
    "    \"std\" dispatches to std_scaler, \"minmax\" to min_max_scaler; any other\n",
    "    value raises ValueError.\n",
    "    \"\"\"\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler(arr, std_scalor= std_scalor)\n",
    "\n",
    "    if scalerfunc == \"minmax\":\n",
    "        return min_max_scaler(arr, min_max_scalor= min_max_scalor)\n",
    "\n",
    "    raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n",
    "\n",
    "def scaler_inv(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \"\"\"Undo the scaling of arr with the scaler selected by scalerfunc.\n",
    "\n",
    "    Args:\n",
    "        arr: scaled data, 3-D or flat.\n",
    "        std_scalor: fitted standard scaler, used for scalerfunc \"std\".\n",
    "        min_max_scalor: fitted min-max scaler, used for scalerfunc \"minmax\".\n",
    "        scalerfunc: \"std\" or \"minmax\".\n",
    "\n",
    "    Returns:\n",
    "        arr mapped back to the original value range.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if scalerfunc names neither scaler.\n",
    "    \"\"\"\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler_inv(arr, std_scalor= std_scalor)\n",
    "\n",
    "    if scalerfunc == \"minmax\":\n",
    "        # the min-max branch must receive the min-max scaler; it was handed\n",
    "        # the standard scaler before, which un-scaled with the wrong statistics\n",
    "        return min_max_scaler_inv(arr, min_max_scalor= min_max_scalor)\n",
    "\n",
    "    raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale both sets with the scaler chosen in func (\"minmax\" or \"std\").\n",
    "# NOTE(review): train_set/test_set are overwritten, so re-running this cell\n",
    "# scales already scaled data.\n",
    "\n",
    "func = \"minmax\"\n",
    "\n",
    "train_set = scaler(train_set, scalerfunc = func)\n",
    "test_set = scaler(test_set, scalerfunc = func)\n",
    "\n",
    "# Remember the matching fitted scaler under the name scalor; it is the default\n",
    "# value of the scalor argument of RNNPlacePrediction defined further down.\n",
    "if func == \"minmax\":\n",
    "    scalor = min_max_scalor\n",
    "elif func == \"std\":\n",
    "    scalor = std_scalor\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    \"\"\"Shuffle input and target together and cut them into minibatches.\n",
    "\n",
    "    Args:\n",
    "        inputt: 3-D input array with one example per entry along axis 0.\n",
    "        target: 3-D target array with the same number of examples.\n",
    "        minibatch_size: number of examples per batch.\n",
    "\n",
    "    Returns:\n",
    "        List of (input_batch, target_batch) tuples. The last batch holds the\n",
    "        remaining examples and may be smaller; no empty batch is produced.\n",
    "    \"\"\"\n",
    "    num_examples = inputt.shape[0]\n",
    "\n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "\n",
    "    # Stepping over the start indices yields the full batches followed by the\n",
    "    # remainder only when it is non-empty. An empty trailing batch (example\n",
    "    # count divisible by minibatch_size) would make the mean-squared-error\n",
    "    # cost NaN during training.\n",
    "    minibatches = []\n",
    "    for first in range(0, num_examples, minibatch_size):\n",
    "        last = first + minibatch_size\n",
    "        minibatches.append((_i[first:last, :, :], _t[first:last, :, :]))\n",
    "\n",
    "    return minibatches\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training minibatches and the test input/target arrays.\n",
    "# The input drops the last time step and the target drops the first one, so\n",
    "# each target step is the step that follows the corresponding input step.\n",
    "\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
    "#_train, _target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "#print(train[0,:,:], target[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \"\"\"Stacked recurrent network mapping an input sequence to an output sequence.\n",
    "\n",
    "    One recurrent cell is created per entry of ncells; the stack is wrapped in\n",
    "    an output projection that emits num_output values per time step. The model\n",
    "    uses the TensorFlow 1.x graph API and owns its own tf.Session.\n",
    "\n",
    "    The cell-type and activation names are kept in the attributes `_` and\n",
    "    `__`; they are used to build the folder names of saved models.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\", scalor= scalor):\n",
    "        \"\"\"Create placeholders, the stacked RNN cell and the session.\n",
    "\n",
    "        Args:\n",
    "            time_steps: number of time steps of every input sequence.\n",
    "            future_steps: stored on the instance; not used to build the graph.\n",
    "            ninputs: number of features per time step (here coordinates).\n",
    "            ncells: int or list of int, the number of units of each layer.\n",
    "            num_output: number of values produced per time step.\n",
    "            cell_type: \"basic_rnn\", \"lstm\" or \"GRU\".\n",
    "            activation: \"relu\", \"tanh\", \"leaky_relu\" or \"elu\"; used by every\n",
    "                layer except the last one, which always uses tanh.\n",
    "            scalor: fitted scaler, stored on the instance.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: for an unknown activation or cell type, or when ncells\n",
    "                is neither an int nor a list of int.\n",
    "        \"\"\"\n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "        self.loss_list = []\n",
    "        self.scalor = scalor\n",
    "\n",
    "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
    "        #### ninputs is the dimentionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "\n",
    "        #Check if activation function valid and set activation\n",
    "        if self.__==\"relu\":\n",
    "            self.activation = tf.nn.relu\n",
    "\n",
    "        elif self.__==\"tanh\":\n",
    "            self.activation = tf.nn.tanh\n",
    "\n",
    "        elif self.__==\"leaky_relu\":\n",
    "            self.activation = tf.nn.leaky_relu\n",
    "\n",
    "        elif self.__==\"elu\":\n",
    "            self.activation = tf.nn.elu\n",
    "\n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(self.__))\n",
    "\n",
    "        #Check if cell type valid and set cell_type\n",
    "        if self._==\"basic_rnn\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicRNNCell\n",
    "\n",
    "        elif self._==\"lstm\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicLSTMCell\n",
    "\n",
    "        elif self._==\"GRU\":\n",
    "            self.cell_type = tf.contrib.rnn.GRUCell\n",
    "\n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(self._))\n",
    "\n",
    "        #Check Input of ncells: a single int is turned into a one-layer list\n",
    "        if (type(self.ncells) == int):\n",
    "            self.ncells = [self.ncells]\n",
    "\n",
    "        if (type(self.ncells) != list):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "\n",
    "        for units in self.ncells:\n",
    "            if type(units) != int:\n",
    "                raise ValueError(\"Wrong type of Input for ncells\")\n",
    "\n",
    "        #every layer uses the chosen activation, except the last one (tanh)\n",
    "        self.activationlist = [self.activation]*(len(self.ncells)-1)\n",
    "        self.activationlist.append(tf.nn.tanh)\n",
    "\n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n",
    "                                                 for layer in range(len(self.ncells))])\n",
    "\n",
    "        #### I now define the output\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "\n",
    "        self.sess = tf.Session()\n",
    "\n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        \"\"\"Unroll the network, define cost and training op, initialise variables.\n",
    "\n",
    "        Args:\n",
    "            LR: learning rate handed to the Adam optimizer.\n",
    "        \"\"\"\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the mean_squared_error (distance of predicted point to target)\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))\n",
    "\n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "\n",
    "    def save(self, rnn_folder=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Write the session variables to the checkpoint path rnn_folder.\"\"\"\n",
    "        self.saver.save(self.sess, rnn_folder)\n",
    "\n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Restore the session variables from the checkpoint path filename.\"\"\"\n",
    "        self.saver.restore(self.sess, filename)\n",
    "\n",
    "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
    "        \"\"\"Train on the given minibatches and save the resulting model.\n",
    "\n",
    "        Args:\n",
    "            minibatches: list of (input, target) array pairs.\n",
    "            epochs: maximum number of epochs to run in this call.\n",
    "            print_step: progress is printed every print_step epochs.\n",
    "            checkpoint: a checkpoint is written at epochs divisible by this\n",
    "                value when the epoch loss beats the reference loss.\n",
    "            patience: training stops early once the epoch loss has changed by\n",
    "                less than 1.5e-7 between consecutive epochs this many times\n",
    "                (the counter is never reset).\n",
    "        \"\"\"\n",
    "        patience_cnt = 0\n",
    "        start = len(self.loss_list)\n",
    "        epoche_save = start\n",
    "        #True once a checkpoint has really been written during this call\n",
    "        checkpoint_saved = False\n",
    "        #index of the last epoch that was run (None if the loop never ran)\n",
    "        last_epoch = None\n",
    "\n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n",
    "\n",
    "        batches = len(minibatches)\n",
    "\n",
    "        for iep in range(start, start + epochs):\n",
    "            last_epoch = iep\n",
    "            loss = 0\n",
    "\n",
    "            #Here I iterate over the batches\n",
    "            for batch in range(batches):\n",
    "                #### Here I train the RNNcell\n",
    "                #### The X is the time series, the Y is shifted by 1 time step\n",
    "                train, target = minibatches[batch]\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "\n",
    "                #cost of this batch, evaluated after the update step\n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "\n",
    "            #Average the loss over the number of batches\n",
    "            loss /= batches\n",
    "            self.loss_list.append(loss)\n",
    "\n",
    "            #Here I create the checkpoint if the perfomance is better\n",
    "            if iep > 1 and iep%checkpoint == 0 and self.loss_list[iep] < self.loss_list[epoche_save]:\n",
    "                #print(\"Checkpoint created at epoch: \", iep)\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "                checkpoint_saved = True\n",
    "\n",
    "            #early stopping with patience\n",
    "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 1.5/10**7:\n",
    "                patience_cnt += 1\n",
    "                #print(\"Patience now at: \", patience_cnt, \" of \", patience)\n",
    "\n",
    "                if patience_cnt + 1 > patience:\n",
    "                    print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
    "                    print(\"Cost: \",loss)\n",
    "                    break\n",
    "\n",
    "            #Note that the loss here is multiplied with 10**6 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*10**6, \"e-6\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "\n",
    "        #Set model back to the last checkpoint if performance was better.\n",
    "        #Restore only when a checkpoint was written in this call; otherwise\n",
    "        #epoche_save still points at the first epoch and there is nothing to load.\n",
    "        if checkpoint_saved and self.loss_list[epoche_save] < self.loss_list[last_epoch]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\")\n",
    "            print(\"State of last checkpoint checkpoint at epoch \", epoche_save, \" restored\")\n",
    "            print(\"Performance at last checkpoint is \" ,(self.loss_list[last_epoch] - self.loss_list[epoche_save])/self.loss_list[last_epoch]*100, \"% better\" )\n",
    "\n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "        self.save(folder)\n",
    "        print(\"\\n\")\n",
    "        print(\"Model saved in at: \", folder)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Run the network on x and return the output for every time step.\"\"\"\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def full_save(rnn):\n",
    "    folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    rnn.save(folder)\n",
    "    pkl_name = folder[2:-10] + \".pkl\"\n",
    "    \n",
    "    \n",
    "    pkl_dic = {\"ncells\": rnn.ncells,\n",
    "              \"ninputs\": rnn.ninputs,\n",
    "              \"future_steps\": rnn.future_steps,\n",
    "              \"nsteps\": rnn.nsteps,\n",
    "              \"num_output\": rnn.num_output,\n",
    "              \"cell_type\": rnn._, #cell_type\n",
    "              \"activation\": rnn.__, #Activation\n",
    "              \"loss_list\": rnn.loss_list,\n",
    "              \"scalor\": rnn.scalor}\n",
    "    pkl.dump( pkl_dic, open(pkl_name , \"wb\" ) )\n",
    "\n",
    "\n",
    "\n",
    "def full_load(folder):\n",
    "    \"\"\"Rebuild a saved model from its settings pickle and its checkpoint.\n",
    "\n",
    "    Args:\n",
    "        folder: checkpoint path of the form \"./<model_dir>/rnn_basic\", e.g. as\n",
    "            returned by get_rnn_folder.\n",
    "\n",
    "    Returns:\n",
    "        RNNPlacePrediction with restored variables and loss history.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the matching .pkl file does not exist.\n",
    "    \"\"\"\n",
    "    #Path of the pkl file: drop the leading \"./\" and the trailing \"/rnn_basic\"\n",
    "    pkl_name = folder[2:-10] + \".pkl\"\n",
    "\n",
    "    #Check if pkl file exists\n",
    "    if not Path(pkl_name).is_file():\n",
    "        raise ValueError(\"There is no .pkl file with the name: {}\".format(pkl_name))\n",
    "\n",
    "    #NOTE(review): unpickling can run arbitrary code; only load files you trust.\n",
    "    #The context manager closes the file once the settings are read.\n",
    "    with open(pkl_name, \"rb\") as pkl_file:\n",
    "        pkl_dic = pkl.load(pkl_file)\n",
    "\n",
    "    ncells = pkl_dic[\"ncells\"]\n",
    "    ninputs = pkl_dic[\"ninputs\"]\n",
    "    scalor = pkl_dic[\"scalor\"]\n",
    "    future_steps = pkl_dic[\"future_steps\"]\n",
    "    timesteps = pkl_dic[\"nsteps\"]\n",
    "    num_output = pkl_dic[\"num_output\"]\n",
    "    cell_type = pkl_dic[\"cell_type\"]\n",
    "    activation = pkl_dic[\"activation\"]\n",
    "\n",
    "    #build the same graph again in an empty default graph, then restore into it\n",
    "    tf.reset_default_graph()\n",
    "    rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs,\n",
    "                        ncells=ncells, num_output=num_output, cell_type=cell_type, activation=activation, scalor=scalor)\n",
    "\n",
    "    rnn.set_cost_and_functions()\n",
    "\n",
    "    rnn.load(folder)\n",
    "\n",
    "    rnn.loss_list = pkl_dic[\"loss_list\"]\n",
    "\n",
    "    return rnn\n",
    "\n",
    "def get_rnn_folder(ncells, cell_type, activation):\n",
    "    folder = \"./rnn_model_\" + cell_type + \"_\" + activation + \"_\" + str(ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    return folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model configuration handed to RNNPlacePrediction in the next cell.\n",
    "# number of time steps per input sequence\n",
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "# features per time step\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [50, 40, 30, 20, 10]\n",
    "\n",
    "# values produced per time step\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty default graph, then build the model in it.\n",
    "tf.reset_default_graph()\n",
    "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define cost, optimizer and saver, and initialise the variables (default learning rate).\n",
    "rnn.set_cost_and_functions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train for up to 5 more epochs, then store the checkpoint and the settings pickle.\n",
    "rnn.fit(minibatches, epochs = 5, print_step=5)\n",
    "full_save(rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#Plot the loss\n",
    "def plot_loss_list(loss_list=None):\n",
    "    \"\"\"Plot a list of per-epoch losses against the epoch index.\n",
    "\n",
    "    Args:\n",
    "        loss_list: losses to plot; when omitted, the loss history of the\n",
    "            notebook-level rnn is used, looked up at call time.\n",
    "    \"\"\"\n",
    "    # Resolve the default at call time so a model created or loaded after this\n",
    "    # definition is picked up, and honour an explicitly passed list.\n",
    "    if loss_list is None:\n",
    "        loss_list = rnn.loss_list\n",
    "    plt.plot(loss_list)\n",
    "    plt.xlabel(\"Epoch\")\n",
    "    plt.ylabel(\"Cost\")\n",
    "    plt.show()\n",
    "\n",
    "# Show the training loss recorded so far.\n",
    "plot_loss_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the checkpoint and the settings pickle for the current model state.\n",
    "full_save(rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#folder = get_rnn_folder(ncells = [50, 40, 30, 20, 10], cell_type = \"lstm\", activation = \"leaky_relu\")\n",
    "#rnn = full_load(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###test_input.shape###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the trained network on the test inputs; the values are still in scaled units.\n",
    "\n",
    "test_pred = rnn.predict(test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Difference between prediction and target after undoing the scaling,\n",
    "# printed for the first particle of the test set.\n",
    "# NOTE(review): both arrays have 7 time steps, so the inverse scaler places\n",
    "# them in the leading 21 of its 24 columns, while test_target holds time\n",
    "# steps 1..7 (zero-based) of each track - confirm that the column statistics\n",
    "# applied here are the intended ones.\n",
    "\n",
    "#scaler_inv(test_input, scalerfunc = func)[0,:,:]\n",
    "\n",
    "\n",
    "diff = scaler_inv(test_pred, scalerfunc = func)-scaler_inv(test_target, scalerfunc = func )\n",
    "\n",
    "print(diff[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model's mean-squared-error cost on the test set (in scaled units).\n",
    "\n",
    "print(\"Loss on test set:\", rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}