{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "# All imports collected in one place (stdlib first, then third-party).\n",
    "# Fix: the original cell imported numpy twice.\n",
    "import math\n",
    "import pickle as pkl\n",
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from sklearn import preprocessing\n",
    "from tensorflow.python.framework import ops\n",
    "\n",
    "#import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the matched-tracks dataset: one row per particle track,\n",
    "# 8 hits per track, each hit with (x, y, z) coordinates -> 24 columns\n",
    "\n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')\n",
    "#print(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test': 1, 'a': 'b'}\n"
     ]
    }
   ],
   "source": [
    "# Quick sanity check that pickle round-trips a dict correctly\n",
    "dic = {\"test\": 1, \"a\": \"b\"}\n",
    "pkl.dump( dic, open( \"save.pkl\", \"wb\" ) )\n",
    "print(pkl.load( open( \"save.pkl\", \"rb\" ) ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the DataFrame to a float32 numpy array of shape (n_tracks, 24)\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "#print(tset.shape)\n",
    "#for i in range(8):\n",
    "    #print(tset[1,3*i:(3*i+3)])\n",
    "#print(tset[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape between flat (n, timesteps*3) and structured (n, timesteps, 3) layouts ###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    \"\"\"Reshape a flat (n, timesteps*3) array to (n, timesteps, 3) float64.\n",
    "\n",
    "    Vectorized replacement for the original per-element copy loops;\n",
    "    row-major reshape moves arr_orig[i, 3*t:3*t+3] to arr[i, t, :].\n",
    "    \"\"\"\n",
    "    timesteps = int(arr_orig.shape[1]/3)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    # np.array(..., dtype=float64) both copies the data and matches the\n",
    "    # dtype the original np.zeros-based implementation returned\n",
    "    return np.array(arr_orig, dtype=np.float64).reshape(number_examples, timesteps, 3)\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    \"\"\"Inverse of reshapor: (n, timesteps, 3) -> flat (n, timesteps*3) float64.\"\"\"\n",
    "    timesteps = int(array_shaped.shape[1])\n",
    "    num_examples = int(array_shaped.shape[0])\n",
    "    return np.array(array_shaped, dtype=np.float64).reshape(num_examples, timesteps*3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set ###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    \"\"\"Shuffle dataset along axis 0 and split it into train/test sets.\n",
    "\n",
    "    train_to_total_ratio: fraction of examples placed in the training set.\n",
    "    Returns (train_set, test_set) as float64 arrays.\n",
    "    \"\"\"\n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "\n",
    "    #evaluate size of training set (np.int is deprecated -> plain int)\n",
    "    train_set_size = int(num_examples*train_to_total_ratio)\n",
    "\n",
    "    #slicing replaces the original per-row copy loop; astype(float64)\n",
    "    #keeps the dtype the old np.zeros-based fill produced\n",
    "    train_set = dataset[:train_set_size,:].astype(np.float64)\n",
    "    test_set = dataset[train_set_size:,:].astype(np.float64)\n",
    "\n",
    "    return train_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split: 99% of the tracks for training, 1% held out for testing\n",
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "\n",
    "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "#print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_min_max_scaler(arr, feature_range= (-1,1)):\n",
    "    \"\"\"Fit a MinMaxScaler on arr (2D flat or 3D structured) and return it.\"\"\"\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    if len(arr.shape) == 3:\n",
    "        arr = reshapor(min_max_scalor.fit_transform(reshapor_inv(arr)))\n",
    "    else:\n",
    "        arr = min_max_scalor.fit_transform(arr)\n",
    "    return min_max_scalor\n",
    "\n",
    "min_max_scalor = set_min_max_scaler(train_set)\n",
    "\n",
    "\n",
    "#transform data; arrays with fewer than 24 flat features are zero-padded\n",
    "#to 24 before the transform and cropped back afterwards\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "\n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(min_max_scalor.transform(reshapor_inv(arr)))\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "\n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = min_max_scalor.transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "\n",
    "    return arr\n",
    "\n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "\n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(min_max_scalor.inverse_transform(reshapor_inv(arr)))\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "\n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = min_max_scalor.inverse_transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            # BUG FIX: was min_max_scalor.nverse_transform -> AttributeError\n",
    "            arr = min_max_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "\n",
    "    return arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn - Standard scaler\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_std_scaler(arr):\n",
    "    \"\"\"Fit a StandardScaler on arr (2D flat or 3D structured) and return it.\"\"\"\n",
    "    std_scalor = preprocessing.StandardScaler()\n",
    "    if len(arr.shape) == 3:\n",
    "        # BUG FIX: fit() returns the scaler itself, not an array, so the\n",
    "        # original reshapor(std_scalor.fit(...)) call crashed on 3D input\n",
    "        std_scalor.fit(reshapor_inv(arr))\n",
    "    else:\n",
    "        std_scalor.fit(arr)\n",
    "    return std_scalor\n",
    "\n",
    "std_scalor = set_std_scaler(train_set)\n",
    "\n",
    "#transform data; arrays with fewer than 24 flat features are zero-padded\n",
    "#to 24 before the transform and cropped back afterwards\n",
    "def std_scaler(arr, std_scalor= std_scalor):\n",
    "\n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(std_scalor.transform(reshapor_inv(arr)))\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "\n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = std_scalor.transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "\n",
    "    return arr\n",
    "\n",
    "#inverse transformation\n",
    "def std_scaler_inv(arr, std_scalor= std_scalor):\n",
    "\n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(std_scalor.inverse_transform(reshapor_inv(arr)))\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "\n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = std_scalor.inverse_transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "\n",
    "    return arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#reshape the data to (n_tracks, 8 timesteps, 3 coordinates)\n",
    "\n",
    "train_set = reshapor(train_set)\n",
    "test_set = reshapor(test_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Scale data either with MinMax scaler or with Standard scaler\n",
    "#Dispatches to the std_scaler / min_max_scaler helpers defined above\n",
    "\n",
    "def scaler(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \"\"\"Apply the chosen forward scaling (\"std\" or \"minmax\") to arr.\"\"\"\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler(arr, std_scalor= std_scalor)\n",
    "\n",
    "    elif scalerfunc == \"minmax\":\n",
    "        return min_max_scaler(arr, min_max_scalor= min_max_scalor)\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Unknown scaler chosen: {}\".format(scalerfunc))\n",
    "\n",
    "def scaler_inv(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \"\"\"Apply the chosen inverse scaling (\"std\" or \"minmax\") to arr.\"\"\"\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler_inv(arr, std_scalor= std_scalor)\n",
    "\n",
    "    elif scalerfunc == \"minmax\":\n",
    "        # BUG FIX: previously passed std_scalor as the min-max scaler here,\n",
    "        # so the minmax inverse transform used the wrong scaler object\n",
    "        return min_max_scaler_inv(arr, min_max_scalor= min_max_scalor)\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"Unknown scaler chosen: {}\".format(scalerfunc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scale the data\n",
    "\n",
    "# choose which scaler to apply: \"minmax\" or \"std\"\n",
    "func = \"minmax\"\n",
    "\n",
    "train_set = scaler(train_set, scalerfunc = func)\n",
    "test_set = scaler(test_set, scalerfunc = func)\n",
    "\n",
    "# keep a handle to the fitted scaler for later inverse transforms\n",
    "if func == \"minmax\":\n",
    "    scalor = min_max_scalor\n",
    "elif func == \"std\":\n",
    "    scalor = std_scalor\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    \"\"\"Shuffle two 3D arrays along axis 0 with the same permutation.\"\"\"\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    \"\"\"Shuffle (inputt, target) in unison and cut them into mini-batches.\n",
    "\n",
    "    Returns a list of (input_batch, target_batch) tuples; the final batch\n",
    "    holds the remainder and is only appended when it is non-empty.\n",
    "    \"\"\"\n",
    "    num_examples = inputt.shape[0]\n",
    "\n",
    "    #Number of complete batches\n",
    "    number_of_batches = int(num_examples/minibatch_size)\n",
    "    minibatches = []\n",
    "\n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "\n",
    "    for i in range(number_of_batches):\n",
    "        minibatch_train = _i[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        minibatch_true = _t[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        minibatches.append((minibatch_train, minibatch_true))\n",
    "\n",
    "    # BUG FIX: only append the remainder batch when it actually contains\n",
    "    # examples; previously an empty batch was appended whenever\n",
    "    # num_examples was divisible by minibatch_size\n",
    "    if num_examples % minibatch_size != 0:\n",
    "        minibatches.append((_i[number_of_batches*minibatch_size:, :, :], _t[number_of_batches*minibatch_size:, :, :]))\n",
    "\n",
    "    return minibatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create random minibatches of train and test set with input and target array\n",
    "#Input: hits 1..7 of each track; target: hits 2..8 (shifted by one time step)\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
    "#_train, _target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "#print(train[0,:,:], target[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \n",
    "    \n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\", scalor= scalor):\n",
    "        \"\"\"Multi-layer RNN that predicts the next hit position of a track.\n",
    "\n",
    "        time_steps   -- number of time steps (hits) per example\n",
    "        future_steps -- steps to predict into the future (stored only)\n",
    "        ninputs      -- features per time step (3 coordinates here)\n",
    "        ncells       -- int or list of int: hidden units per RNN layer\n",
    "        num_output   -- output features per time step\n",
    "        cell_type    -- \"basic_rnn\", \"lstm\" or \"GRU\"\n",
    "        activation   -- \"relu\", \"tanh\", \"leaky_relu\" or \"elu\"\n",
    "        scalor       -- fitted scaler, kept for later (de-)normalization\n",
    "        \"\"\"\n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "        self.loss_list = []\n",
    "        self.scalor = scalor\n",
    "        \n",
    "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
    "        #### ninputs is the dimensionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "\n",
    "        \n",
    "        #Check if activation function valid and set activation\n",
    "        if self.__==\"relu\":\n",
    "            self.activation = tf.nn.relu\n",
    "            \n",
    "        elif self.__==\"tanh\":\n",
    "            self.activation = tf.nn.tanh\n",
    "                    \n",
    "        elif self.__==\"leaky_relu\":\n",
    "            self.activation = tf.nn.leaky_relu\n",
    "            \n",
    "        elif self.__==\"elu\":\n",
    "            self.activation = tf.nn.elu\n",
    "            \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(self.__))\n",
    "        \n",
    "        \n",
    "        \n",
    "        #Check if cell type valid and set cell_type\n",
    "        if self._==\"basic_rnn\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicRNNCell\n",
    "            \n",
    "        elif self._==\"lstm\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicLSTMCell\n",
    "                    \n",
    "        elif self._==\"GRU\":\n",
    "            self.cell_type = tf.contrib.rnn.GRUCell\n",
    "            \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(self._))\n",
    "            \n",
    "        \n",
    "        #Check Input of ncells: normalize a bare int to a one-element list\n",
    "        if (type(self.ncells) == int):\n",
    "            self.ncells = [self.ncells]\n",
    "        \n",
    "        if (type(self.ncells) != list):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        \n",
    "        for _ in range(len(self.ncells)):\n",
    "            if type(self.ncells[_]) != int:\n",
    "                raise ValueError(\"Wrong type of Input for ncells\")\n",
    "                \n",
    "        # chosen activation for all hidden layers; the LAST layer always\n",
    "        # uses tanh so outputs stay in [-1, 1] (matches minmax scaling)\n",
    "        self.activationlist = []\n",
    "        for _ in range(len(self.ncells)-1):\n",
    "            self.activationlist.append(self.activation)\n",
    "        self.activationlist.append(tf.nn.tanh)\n",
    "        \n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n",
    "                                                 for layer in range(len(self.ncells))])\n",
    "            \n",
    "        \n",
    "        #### Projection layer maps the last RNN layer to num_output features\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        self.sess = tf.Session()\n",
    "        \n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        \"\"\"Build the unrolled graph, MSE cost and Adam optimizer; init vars.\n",
    "\n",
    "        LR -- learning rate for the Adam optimizer.\n",
    "        \"\"\"\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the mean_squared_error (distance of predicted point to target)\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))   \n",
    "        \n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "  \n",
    "    \n",
    "    def save(self, rnn_folder=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Write a TF checkpoint of the current session to rnn_folder.\"\"\"\n",
    "        self.saver.save(self.sess, rnn_folder)       \n",
    "            \n",
    "            \n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Restore session variables from a TF checkpoint.\"\"\"\n",
    "        self.saver.restore(self.sess, filename)\n",
    "\n",
    "               \n",
    "        \n",
    "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
    "        \"\"\"Train on the given mini-batches with checkpointing + early stopping.\n",
    "\n",
    "        minibatches -- list of (input_batch, target_batch) tuples\n",
    "        epochs      -- number of additional epochs to train\n",
    "        print_step  -- print progress every print_step epochs\n",
    "        checkpoint  -- save a checkpoint every `checkpoint` epochs if improved\n",
    "        patience    -- stop after this many near-flat loss changes\n",
    "        \"\"\"\n",
    "        patience_cnt = 0\n",
    "        # resume epoch numbering from any previous fit() call\n",
    "        start = len(self.loss_list)\n",
    "        epoche_save = start\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n",
    "        \n",
    "        for iep in range(start, start + epochs):\n",
    "            loss = 0\n",
    "            \n",
    "            batches = len(minibatches)\n",
    "            #Here I iterate over the batches\n",
    "            for batch in range(batches):\n",
    "            #### Here I train the RNNcell\n",
    "            #### The X is the time series, the Y is shifted by 1 time step\n",
    "                train, target = minibatches[batch]\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "                \n",
    "            \n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "            \n",
    "            #Average the accumulated loss over the number of batches\n",
    "            loss /= batches\n",
    "            self.loss_list.append(loss)\n",
    "            \n",
    "            #print(loss)\n",
    "            \n",
    "            #Here I create the checkpoint if the perfomance is better\n",
    "            if iep > 1 and iep%checkpoint == 0 and self.loss_list[iep] < self.loss_list[epoche_save]:\n",
    "                #print(\"Checkpoint created at epoch: \", iep)\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "            \n",
    "            #early stopping with patience\n",
    "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 1.5/10**7:\n",
    "                patience_cnt += 1\n",
    "                #print(\"Patience now at: \", patience_cnt, \" of \", patience)\n",
    "                \n",
    "                if patience_cnt + 1 > patience:\n",
    "                    print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
    "                    print(\"Cost: \",loss)\n",
    "                    break\n",
    "            \n",
    "            #Note that the printed loss is scaled by 10^6 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*10**6, \"e-6\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "        \n",
    "        #Set model back to the last checkpoint if performance was better\n",
    "        if self.loss_list[epoche_save] < self.loss_list[iep]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\")\n",
    "            print(\"State of last checkpoint checkpoint at epoch \", epoche_save, \" restored\")\n",
    "            print(\"Performance at last checkpoint is \" ,(self.loss_list[iep] - self.loss_list[epoche_save])/self.loss_list[iep]*100, \"% better\" )\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "        self.save(folder)\n",
    "        print(\"\\n\")\n",
    "        print(\"Model saved in at: \", folder)\n",
    "            \n",
    "            \n",
    "        \n",
    "    def predict(self, x):\n",
    "        \"\"\"Run the forward pass and return the network output for input x.\"\"\"\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def full_save(rnn):\n",
    "    \"\"\"Save TF checkpoint plus a .pkl with all model hyperparameters.\"\"\"\n",
    "    folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    rnn.save(folder)\n",
    "    pkl_name = folder[2:-10] + \".pkl\"\n",
    "    \n",
    "    \n",
    "    pkl_dic = {\"ncells\": rnn.ncells,\n",
    "              \"ninputs\": rnn.ninputs,\n",
    "              \"future_steps\": rnn.future_steps,\n",
    "              \"nsteps\": rnn.nsteps,\n",
    "              \"num_output\": rnn.num_output,\n",
    "              \"cell_type\": rnn._, #cell_type\n",
    "              \"activation\": rnn.__, #Activation\n",
    "              \"loss_list\": rnn.loss_list,\n",
    "              \"scalor\": rnn.scalor}\n",
    "    pkl.dump( pkl_dic, open(pkl_name , \"wb\" ) )\n",
    "\n",
    "\n",
    "\n",
    "def full_load(folder):\n",
    "    \"\"\"Rebuild an RNNPlacePrediction from a checkpoint folder + .pkl file.\"\"\"\n",
    "    #Directory of pkl file\n",
    "    pkl_name = folder[2:-10] + \".pkl\"\n",
    "    \n",
    "    #Check if pkl file exists\n",
    "    my_file = Path(pkl_name)\n",
    "    if my_file.is_file() == False:\n",
    "        raise ValueError(\"There is no .pkl file with the name: {}\".format(pkl_name))\n",
    "        \n",
    "    # NOTE(review): pickle.load can run arbitrary code - only load trusted files\n",
    "    pkl_dic = pkl.load( open(pkl_name , \"rb\" ) )\n",
    "    ncells = pkl_dic[\"ncells\"]\n",
    "    ninputs = pkl_dic[\"ninputs\"]\n",
    "    scalor = pkl_dic[\"scalor\"]\n",
    "    future_steps = pkl_dic[\"future_steps\"]\n",
    "    timesteps = pkl_dic[\"nsteps\"] \n",
    "    num_output = pkl_dic[\"num_output\"]\n",
    "    cell_type = pkl_dic[\"cell_type\"]\n",
    "    activation = pkl_dic[\"activation\"]\n",
    "\n",
    "    tf.reset_default_graph()\n",
    "    rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=cell_type, activation=activation, scalor=scalor)\n",
    "\n",
    "    rnn.set_cost_and_functions()\n",
    "    \n",
    "    rnn.load(folder)\n",
    "    \n",
    "    rnn.loss_list = pkl_dic[\"loss_list\"]\n",
    "    \n",
    "    return rnn\n",
    "\n",
    "def get_rnn_folder(ncells, cell_type, activation):\n",
    "    \"\"\"Build the canonical checkpoint folder name for a configuration.\"\"\"\n",
    "    folder = \"./rnn_model_\" + cell_type + \"_\" + activation + \"_\" + str(len(ncells)) + \"l_\" + str(ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    return folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model configuration: 7 input steps predict the next hit (8 hits total)\n",
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [50, 40, 30, 20, 10]\n",
    "\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a fresh graph and instantiate the model\n",
    "tf.reset_default_graph()\n",
    "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define loss, optimizer and initialize all TF variables\n",
    "rnn.set_cost_and_functions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch number  5\n",
      "Cost:  138389.9095210623 e-6\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  5 \n",
      "\n",
      "\n",
      "\n",
      "Model saved in at:  ./rnn_model_lstm_leaky_relu_[50,40,30,20,10]c/rnn_basic\n"
     ]
    }
   ],
   "source": [
    "# Train for 5 epochs and persist the model weights + metadata pickle\n",
    "rnn.fit(minibatches, epochs = 5, print_step=5)\n",
    "full_save(rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd81eX5//HXlcHekLDDDEtkhshQcaCCSqBOcGGrUgdqtbbV1l+ttrbWfrVaxYGjilVxt9FarQuUoRC2TCGMxIAEUKaEjOv3Rw40YiBBcvicnPN+Ph7nwfmskytHk3fuz33u+zZ3R0RE5FDigi5AREQin8JCREQqpLAQEZEKKSxERKRCCgsREamQwkJERCqksBARkQopLEREpEIKCxERqVBC0AVUlWbNmnn79u2DLkNEpFqZO3fuZndPqui8qAmL9u3bk5WVFXQZIiLVipmtq8x5ug0lIiIVUliIiEiFFBYiIlIhhYWIiFRIYSEiIhVSWIiISIUUFiIiUqGYD4uSEuePby9j/ZbdQZciIhKxYj4s1m7ZxZTZ6znzb5/wrwVfBl2OiEhEivmw6JhUj7dvPIGuLepz45QF/PLVhezeWxR0WSIiESXmwwKgTeM6vDR+IBNO7swrc3MZ+dB0luZtD7osEZGIobAISYiP45YzuvL8FcexY08Rox+ZwbMz1+LuQZcmIhI4hcUBBnduxn9uPIEhnZpyR+YSxj83l2927w26LBGRQCksytG0Xk2evnwAt5/VnakrNjHiwU+YvWZr0GWJiARGYXEQZsaVJ3Tk9WuGUDMhjjGTZvHA+yspLtFtKRGJPQqLChzbpiFv3XACo/u05oH3v+CiJz5lw7Zvgy5LROSoUlhUQr2aCdx/YR/uO783i7/cxogHP+G9pV8FXZaIyFGjsDgM5/Zvw1vXH0/rRrW5anIWv8tcwp7C4qDLEhEJO4XFYeqYVI/Xrx3MT4Z04JmZaznnkZmszt8ZdFkiImGlsPgBaibE89uRPXhqXBobtn3LyIem80pWjsZkiEjUUlgcgVO7N+c/N55IrzYN+cWri/jZSwvYsacw6LJERKpcWMPCzIab2QozW2Vmt5Zz/GozW2xmC8xsupn1CO1vb2bfhvYvMLPHwlnnkWjRsBbPXzmQn5/WhTcX5nH2Q9NZlPtN0GWJiFSpsIWFmcUDE4ERQA9g7L4wKOMFdz/W3fsA9wL3lzm22t37hB5Xh6vOqhAfZ1x/aiov/XQQhUUlnPvoTJ74OJsSjckQkSgRzpZFOrDK3bPdfS8wBRhV9gR3LztbX12gWv92HdC+CW/feAKndEvm7reX8eNn5rB5Z0HQZYmIHLFwhkVrIKfMdm5o33eY2XVmtprSlsUNZQ51MLP5ZjbNzE4IY51VqlGdGjx2SX9+P7ons7K3MOLBT5ixanPQZYmIHJFwhoWVs+97LQd3n+junYBfAbeHdm8AUty9L3Az8IKZNfjeFzAbb2ZZZpaVn59fhaUfGTPj0oHt+Nd1Q2hYO5FLnvqMe99ZTmFxSdCliYj8IOEMi1ygbZntNkDeIc6fAowGcPcCd98Sej4XWA10OfACd5/k7mnunpaUlFRlhVeV7i0bkDlhCBemteWRqau54PFZ5GzV8q0iUv2EMyzmAKlm1sHMagBjgMyyJ5hZapnNs4AvQvuTQh3kmFlHIBXIDmOtYVOnRgL3nNuLh8b2ZdVXOznzb5/w70Ubgi5LROSwhC0s3L0ImAC8CywDXnb3JWZ2l5llhE6bYGZLzGwBpbebxoX2nwgsMrOFwKvA1e5erecIH9m7FW/feAKdkupx3QvzuO31xXy7V1OFiEj1YNEy6jgtLc2zsrKCLqNChcUl3PfflTw2bTVdmtfjobH96NqiftBliUiMMrO57p5W0XkawX2UJcbHceuIbjx3RTpbdxWS8fB0nv9snaYKEZGIprAIyAmpSfznxhNI79CE
37zxOdc+P49tuzVViIhEJoVFgJLq1+TZH6dz24huvLf0K8782yes27Ir6LJERL5HYRGwuDjjp0M78eo1g9m1t4grn83SZIQiEnEUFhGiT9tGPHJRP7I37+KmlxZoXikRiSgKiwgyuHMz7hjZg/eXbeK+91YEXY6IyH4JQRcg33XpwHYs27CDiR+tpkvz+ozq873ptEREjjq1LCKMmXFnxjGkt2/CL19dxOLcbUGXJCKisIhENRLiePSSfjSrV5OrJmexaceeoEsSkRinsIhQTevV5InL0tj2bSE/fW4uBUWaGkREgqOwiGA9WjXg/gt6M3/9N/zmjc81yltEAqOwiHAjjm3Jz4al8urcXJ6avibockQkRiksqoEbTkllRM8W/PHtZUxbGTmLPIlI7FBYVANxccZ9F/Sma4sGTHhhHtn5O4MuSURijMKimqhTI4EnLutPjfg4rpycxbZvNSWIiBw9CotqpE3jOjx6SX/Wb9nNDS/Op1hTgojIUaKwqGbSOzTh96N7Mm1lPn9+Z3nQ5YhIjNB0H9XQ2PQUlm/YzqSPs+nSvD7n9W8TdEkiEuXUsqimbj+7B4M7NeXXry9m3vqvgy5HRKKcwqKaSoyPY+JF/WjRsBY/fW4uG7Z9G3RJIhLFFBbVWOO6NXhyXBq7C4oYP3kuewo1JYiIhIfCoprr0rw+D47py+d52/jlq4s0JYiIhIXCIgoM69GcW07vSubCPB6dtjrockQkCoU1LMxsuJmtMLNVZnZrOcevNrPFZrbAzKabWY8yx24LXbfCzM4IZ53R4NqTOjGydyv+8u4K3l/6VdDliEiUCVtYmFk8MBEYAfQAxpYNg5AX3P1Yd+8D3AvcH7q2BzAGOAYYDjwSej05CDPj3nN70bNVQ26cMp+VX+0IuiQRiSLhbFmkA6vcPdvd9wJTgFFlT3D37WU26wL7briPAqa4e4G7rwFWhV5PDqF2jXgmXdaf2jUSuPLZLL7etTfokkQkSoQzLFoDOWW2c0P7vsPMrjOz1ZS2LG44zGvHm1mWmWXl52s2VoCWDWvz+KX92bhtD9e9MI/C4pKgSxKRKBDOsLBy9n3vozruPtHdOwG/Am4/zGsnuXuau6clJSUdUbHRpH+7xvzxnGOZuXoLd/97WdDliEgUCOd0H7lA2zLbbYC8Q5w/BXj0B14rBzivfxuWb9jOk9PX0LVFfcampwRdkohUY+FsWcwBUs2sg5nVoLTDOrPsCWaWWmbzLOCL0PNMYIyZ1TSzDkAqMDuMtUalW0d048QuSfz2X58ze83WoMsRkWosbGHh7kXABOBdYBnwsrsvMbO7zCwjdNoEM1tiZguAm4FxoWuXAC8DS4F3gOvcXcOTD1NCfBwPje1L28Z1uOYfc8n9enfQJYlINWXRMuI3LS3Ns7Kygi4jIq3O38noiTNo07gOr10ziDo1NNmwiJQys7nunlbReRrBHQM6JdXj4Yv6sWLjdn7+8kJKtGiSiBwmhUWMGNoliV+f2Z3/fL6Rhz5cFXQ5IlLN6H5EDLni+A4s27CDv76/kq4t6jG8Z8ugSxKRakItixhiZtz9o570TWnETS8tZGne9oovEhFBYRFzaiXG8/gl/WlYO5GrJmexZWdB0CWJSDWgsIhByQ1qMemy/mzeWcA1z89jb5GmBBGRQ1NYxKhebRpx73m9mL1mK3dkLtGiSSJySOrgjmGj+rRmxcYdPDJ1Nd1b1ueyQe2DLklEIpRaFjHultO7Mqx7Mne+uZSZqzYHXY6IRCiFRYyLizP+emEfOjary7UvzGPdll1BlyQiEUhhIdSvlciT49JwhyufzWLHnsKgSxKRCKOwEADaNa3LIxf3I3vzLm56aQFFWjRJRMpQWMh+Qzo343cje/D+sk3c/PJCBYaI7KdPQ8l3XDqoPTsLivnzO8sxg/vO701CvP6mEIl1Cgv5nmtO6oTj3PvOCgDuv6AP8XHlrXQrIrFCYSHluvakzrjDX95dgQH3KTBEYprC
Qg7qupM7A6HAMOP/zu+twBCJUQoLOaTrTu6Mu/N//12JAX9RYIjEJIWFVGjCKam4w33vrQSDv5ynwBCJNQoLqZTrT03FgfvfW4lh3HteLwWGSAxRWEil3XBqaQvjr++vxAz+fK4CQyRWKCzksNw4LBXHeeD9LzBKAyNOgSES9cI62srMhpvZCjNbZWa3lnP8ZjNbamaLzOwDM2tX5lixmS0IPTLDWaccnp8N68KNp6byytxcfvXaIkpKtBaGSLQLW8vCzOKBicBpQC4wx8wy3X1pmdPmA2nuvtvMrgHuBS4MHfvW3fuEqz45Mjed1gUH/vbBF5jBPeeohSESzcJ5GyodWOXu2QBmNgUYBewPC3f/qMz5nwKXhLEeqWI3DUsFd/724SoM40/nHKvAEIlS4QyL1kBOme1c4LhDnH8F8J8y27XMLAsoAu5x938eeIGZjQfGA6SkpBxxwXJ4zGx/C+OhD1dhBn/8kQJDJBqFMyzK+41R7s1tM7sESAOGltmd4u55ZtYR+NDMFrv76u+8mPskYBJAWlqabpwHwMy4+bQuuMPDH5UGxt2jFRgi0SacYZELtC2z3QbIO/AkMxsG/AYY6u4F+/a7e17o32wzmwr0BVYfeL0Ez8z4+eldcJyJH60GjLtH91RgiESRcIbFHCDVzDoAXwJjgIvKnmBmfYHHgeHuvqnM/sbAbncvMLNmwBBKO78lQpkZt5zeFXd4ZOpqzOAPoxQYItEibGHh7kVmNgF4F4gHnnb3JWZ2F5Dl7pnAX4B6wCtmBrDe3TOA7sDjZlZC6cd77zngU1QSgcyMX5zRFQcenboaA36vwBCJCmEdlOfubwNvH7Dvt2WeDzvIdTOBY8NZm4SHmfHLM0pbGI9NK21h/H5UT0J/DIhINaUR3FLlzIxfDe+K4zw+LRtQYIhUdwoLCQsz49bh3cDh8Y+zMYy7Rh2jwBCpphQWEjZmxq0juuHApI+ziTP4XYYCQ6Q6UlhIWJkZt43ohrvzxCdrMDPuGNlDgSFSzSgsJOzMjF+f2R13eHL6GgAFhkg1U6lZZ83sucrsEzkYM+M3Z3XniuM78MzMtdz11lLcNehepLqobMvimLIboRll+1d9ORLNzIzbz+pOiTt/n7EWw/h/Z3dXC0OkGjhkWJjZbcCvgdpmtn3fbmAvoTmZRA6HmfHbs3sA8PSMNZjB7WcpMEQi3SHDwt3/BPzJzP7k7rcdpZokyu0LDHd4avoaDPiNAkMkolX2NtRbZlbX3XeFZojtBzzo7uvCWJtEsX2fioLSTm8z+PWZCgyRSFXZsHgU6G1mvYFfAk8Bk/nulOIih2VfYJT9WO1tI7opMEQiUGXDosjd3cxGUdqieMrMxoWzMIkNZsbvMo7ZP3DPgFsVGCIRp7JhsSPU2X0pcELo01CJ4StLYomZcWfGMXhoahAMbh2uwBCJJJUNiwspXYviJ+6+0cxSKJ1eXKRKmJXOHbVv8kGjdDJCBYZIZKhUWIQC4nlggJmdDcx298nhLU1ijZlxV0bP/dObl7hz6/BuWg9DJAJUdgT3BcBs4HzgAuAzMzsvnIVJbIqLM34/qieXDmzHpI+zufofc9lZUBR0WSIxr1JhQeka2QPcfZy7XwakA/8vfGVJLIuLK70ldcfIHnywfBPnPDKDdVt2BV2WSEyrbFjElV0jG9hyGNeKHDYz48dDOjD5J+l8tb2AjIdnMP2LzUGXJRKzKvsL/x0ze9fMLjezy4F/c8ByqSLhMKRzMzInDKF5g5qM+/tsnp6+RhMQigTgkGFhZp3NbIi7/wJ4HOgF9AZmobmh5Chp17Qur187hFO7JXPXW0v55auLKCgqDroskZhSUcviAWAHgLu/7u43u/tNlLYqHgh3cSL71KuZwGOX9OeGU1N5ZW4uYyZ9yqbte4IuSyRmVBQW7d190YE73T0LaB+WikQOIi7OuPm0Ljx6cT9WbNzByIenszDnm6DLEokJFYVFrUMcq13Ri5vZcDNbYWar
zOzWco7fbGZLzWyRmX1gZu3KHBtnZl+EHppaRPYbcWxLXrtmMInxcZz/+CzemJ8bdEkiUa+isJhjZlcduNPMrgDmHurC0JQgE4ERQA9grJn1OOC0+UCau/cCXgXuDV3bBLgDOI7Sj+neYWaNK/52JFZ0b9mAzAnH0y+lETe9tJA/vr2M4hJ1fIuES0UjuH8GvGFmF/O/cEgDagA/quDadGCVu2cDmNkUYBSwdN8J7v5RmfM/BS4JPT8DeM/dt4aufQ8YDrxY0TcksaNJ3Ro8d8Vx/OGtpUz6OJtlG7bz8Nh+NKyjactEqtohWxbu/pW7DwbuBNaGHne6+yB331jBa7cGcsps54b2HcwVwH9+4LUSoxLj47hzVE/uOedYPs3ewuhHZrBq046gyxKJOpUaZ+HuH7n7Q6HHh5V87fIm9Cn3PkFoQaU0/jc5YaWuNbPxZpZlZln5+fmVLEui0Zj0FF68aiA79hQyeuJMPlj2VdAliUSVcI7CzgXaltluA+QdeJKZDaN0OpEMdy84nGvdfZK7p7l7WlJSUpUVLtVTWvsmZE44nvbN6nDl5CwmfrRKA/hEqkg4w2IOkGpmHcysBjAGyCx7gpn1pXSwX8YB04m8C5xuZo1DHdunh/aJHFKrRrV55aeDGdmrFX95dwXXvzifb/dqAJ/IkarsehaHzd2LzGwCpb/k44Gn3X2Jmd0FZLl7JqW3neoBr4TWLVjv7hnuvtXMfk9p4ADcta+zW6QitWvE8+CYPvRo1YA/v7OcNZt3MemyNFo3qvDT3iJyEBYtzfS0tDTPysoKugyJMB8t38QNL86nRkIcj17Sn/QOTYIuSSSimNlcd0+r6DzNHCtR7eRuybxx3RAa1k7koic+5fnP1gVdkki1pLCQqNc5uR5vXDeEIZ2b8Zs3Puf2fy6msLgk6LJEqhWFhcSEhrUTefryAfx0aEf+8el6Ln7yM7bsLKj4QhEBFBYSQ+LjjNtGdOeBC/uwMOcbMh6ewZK8bUGXJVItKCwk5ozu25pXrh5EcYlz3qOz+PeiDUGXJBLxFBYSk3q1aUTm9UPo3rI+170wj/v+u4ISTUQoclAKC4lZyfVr8eL4gVyY1paHPlzF+Oey2LGnMOiyRCKSwkJiWs2EeO4591juzDiGj1bkc84jM1m7eVfQZYlEHIWFxDwzY9zg9jz3k3TydxaQ8fB0Pl6piSlFylJYiIQM7tyMzOuOp2XD2lz+99k8+Um2JiIUCVFYiJSR0rQOr187mNN6NOcP/17GLa8sYk+hJiIUUViIHKBuzQQevbg/PxuWymvzchk9cQbz1n8ddFkigVJYiJQjLs742bAuPDUujW92F3LuozO5/Z+L2a5PS0mMUliIHMKp3Zvz3s0nMm5Qe57/bD3D7pvG24s3qC9DYo7CQqQC9Wsl8ruMY/jntUNIql+Ta5+fxxXPZpGzdXfQpYkcNQoLkUrq3bYR/7puCLef1Z1Zq7dw+l8/ZtLHqynSDLYSAxQWIochIT6OK0/oyHs3n8jgTk3549vLGfnwDBbkfBN0aSJhpbAQ+QHaNK7Dk+PSeOySfmzdVcCPHpnB7zKXaLoQiVoKC5EfyMwY3rMl7988lMsGtuPZWWsZdv803vlcHeASfRQWIkeofq1E7hzVkzeuHUKTujW5+h/zuGpyFl9+823QpYlUGYWFSBXp07YRb04Ywm/O7M6MVVs47f5pPPlJtjrAJSooLESqUEJ8HFed2JH/3nQix3Vowh/+vYzRj8xgca5W5JPqLaxhYWbDzWyFma0ys1vLOX6imc0zsyIzO++AY8VmtiD0yAxnnSJVrW2TOjx9+QAmXtSPr7YXMGridO58cwk7C4qCLk3kB0kI1wubWTwwETgNyAXmmFmmuy8tc9p64HLglnJe4lt37xOu+kTCzcw4q1dLTujSjL+8s4JnZq7lnc83cmfGMZx+TIugyxM5LOFsWaQDq9w92933AlOAUWVPcPe17r4I0E1diVoNaiXy+9E9
ee2awTSsncj45+YyfnIWeeoAl2oknGHRGsgps50b2ldZtcwsy8w+NbPRVVuayNHXL6Uxb15/PLeO6MbHX+Rz2v3TeHr6Goq19rdUA+EMCytn3+H8VKS4expwEfCAmXX63hcwGx8KlKz8fK1sJpEvMT6Oq4d24r2bhpLWvgl3vbWU0RNn8PmX6gCXyBbOsMgF2pbZbgPkVfZid88L/ZsNTAX6lnPOJHdPc/e0pKSkI6tW5Chq26QOz/x4AA+N7cuGbXvIeHg6f3hrKbvUAS4RKpxhMQdINbMOZlYDGANU6lNNZtbYzGqGnjcDhgBLD32VSPViZozs3YoPfj6UsekpPDl9DafdP433l34VdGki3xO2sHD3ImAC8C6wDHjZ3ZeY2V1mlgFgZgPMLBc4H3jczJaELu8OZJnZQuAj4J4DPkUlEjUa1k7k7h8dy2vXDKJ+rUSunJzF1c/NZeO2PUGXJrKfRcscNmlpaZ6VlRV0GSJHpLC4hCc+yebB978gMT6OX5zRlUsGtiM+rrwuQJEjZ2ZzQ/3Dh6QR3CIRJDE+jmtP6sx/bzqRvimNuCNzCec8OpMleeoAl2ApLEQiULumdZn8k3QeHNOHL7/eTcbDM/j5ywtZlKt1MyQYYRvBLSJHxswY1ac1J3VJ5q/vr+TlrBxem5dLn7aNGDe4HWce25KaCfFBlykxQn0WItXE9j2FvD43l8mz1pG9eRdN69ZgbHoKFw9MoWXD2kGXJ9VUZfssFBYi1UxJiTNj9WaenbmOD5Z/RZwZZxzTnMsGtee4Dk0wU2e4VF5lw0K3oUSqmbg444TUJE5ITSJn627+8dk6XpqTw9uLN9K1eX0uG9yO0X1aU7emfryl6qhlIRIF9hQWk7kwj2dnrmVJ3nbq10rg/P5tuXRQOzo0qxt0eRLBdBtKJAa5O/PWf8OzM9fy9uINFJU4Q7skMW5wO4Z2SdZ4DfkehYVIjNu0fQ8vzs7h+c/WsWlHASlN6nDpwHacn9aGRnVqBF2eRAiFhYgApaPC312ykckz1zF77VZqJcYxuk9rLhvUnh6tGgRdngRMYSEi37M0bzvPfbqWN+Z/yZ7CEga0b8xlg9ozvGcLEuM1RjcWKSxE5KC27S7klbk5TJ61jvVbd5NcvyYXHZfCRekpJDeoFXR5chQpLESkQiUlzrSV+Tw7ay1TV+STEGeMOLYl4wa1o3+7xhqzEQM0zkJEKhQXZ5zcLZmTuyWzdvMunvt0HS9n5fDmwjx6tGzAuMHtyOjdmto1NK1IrFPLQkS+Y/feIv45P4/Js9ayfOMOGtZO5MIBbbnkuHakNK0TdHlSxXQbSkSOiLsze81WJs9axztLNlLizildk7l4YIrGbEQR3YYSkSNiZhzXsSnHdWzKxm17eOGzdbwwO4cPnsmiVcNaXDCgLRektaVVI01iGAvUshCRSttbVMIHy77ixTk5fPJFPgac3DWZsekpnNQ1iQR9/Lba0W0oEQmrnK27eWlODi9l5ZC/o4DmDWpyYVpbLhjQljaN1bdRXSgsROSoKCwu4cPlm3hx9nqmrcwHYGiXJMamp3BKt2QN9otwCgsROepyv97Ny1m5vDwnh43b95Bcvybnp7VhzIAU2jZRayMSKSxEJDBFxSVMXZHPi7PX89GKTThwfOdmXJSewrAezdXaiCCVDYuw/hczs+FmtsLMVpnZreUcP9HM5plZkZmdd8CxcWb2RegxLpx1ikjVSoiPY1iP5jx1+QCm/+oUbjw1ldWbdnLN8/MY9KcPuec/y1m7eVfQZcphCFvLwszigZXAaUAuMAcY6+5Ly5zTHmgA3AJkuvurof1NgCwgDXBgLtDf3b8+2NdTy0IkshWXOB+vzOeF2ev5cPkmikucIZ2bMjY9hdN7tKBGglobQYiEcRbpwCp3zw4VNAUYBewPC3dfGzpWcsC1ZwDvufvW0PH3gOHAi2GsV0TCKL7M1CIbt+3hlawcpszJYcIL82latwbn9W/DhQPa0jGp
XtClSjnCGRatgZwy27nAcUdwbesqqktEAtaiYS2uPzWVa0/uzPRVm3nxs/U8OX0Nj3+czcCOTRibnsLwni2omaA5qSJFOMOivLkAKnvPq1LXmtl4YDxASkpK5SsTkYgQH2cM7ZLE0C5JbNq+h1fm5jJlznpunLKAxnUSObdfG8akp9A5Wa2NoIXzJmEu0LbMdhsgryqvdfdJ7p7m7mlJSUk/uFARCV5yg1pcd3Jnpt1yMv+44jgGd2rGMzPXMuz+aVzw2CzemJ/LnsLioMuMWeHs4E6gtIP7VOBLSju4L3L3JeWc+wzw1gEd3HOBfqFT5lHawb31YF9PHdwi0Sd/RwGvzctlyuz1rN2ym4a1EzmnX2vGpqfQpXn9oMuLChExzsLMzgQeAOKBp939bjO7C8hy90wzGwC8ATQG9gAb3f2Y0LU/AX4deqm73f3vh/paCguR6FVS4nyavYUXZq/n3SUbKSx2urdsQEbvVpzdq6UG/B2BiAiLo0lhIRIbtuws4F8L8nhzUR7z138DQL+URozs3YqzerUkub6WhT0cCgsRiXo5W3fz5qI8MhfksXzjDuIMBnZsSkbvVozo2ZKGdRKDLjHiKSxEJKZ88dUO3lyYR+bCPNZu2U1ivHFiahIZfVoxrHtz6tbU8j3lUViISExydz7/cjuZC7/krUUb2LBtD7US4zi1e3MyerdiaJckaiVq/MY+CgsRiXklJU7Wuq/JXPglby/eyNZde6lfM4EzerZgZO9WDOnUNOYXbFJYiIiUUVRcwozVW3hzYR7vfr6RHQVFNK1bgxHHtiCjd2vS2jUmLgbXFVdYiIgcxJ7CYqatzCdzYR4fLPuKPYUltGxYi7N7tSSjd2t6tm6AWWwEh8JCRKQSdhUU8f6yr8hckMfHX+RTWOx0aFaXkb1aMrJ3K1KjfPCfwkJE5DB9s3sv73y+kTcX5TFr9RZKHLq1qE9Gn1aM7NUqKgf/KSxERI7Aph17+PeiDby5MI95ocF/fVMaMbJX6ajx5AbRMfhPYSEiUkVytu7mrUUbyFyYx7IN2zGDgR2acnbvlpzcNZlWjWoHXeIPprAQEQmDVZt2kLmwtMWxJrQ0bGpyPU7qmsRJXZNJa9+4Wq3DobAQEQkjd2fVpp1+ZNhkAAAHQ0lEQVRMXZHP1JWbmL1mK4XFTp0a8Qzu1JShXZM5qUtSxPdzRMKyqiIiUcvMSG1en9Tm9bnqxI7sKihi1uotTF25iakr8nl/2SYAOiXVZWiXZE7qmkR6hybVdvS4WhYiIlXM3Vmdv4tpK/OZumITn63Zyt6iEmonxjOoU1NO6lq6OmC7pnWDLlUtCxGRoJgZnZPr0Tm5Hlcc34Hde4v4NHtL6S2rFfl8uLy01dGhWV2GdknipK5JDOzYNKJbHWpZiIgcZWs272LqitLbVZ9mb6GgqISaCXEM7Nh0f0d5h2ZHp9WhDm4RkWpgT2Hx/lbHtJX5+z9h1a5pnf2tjkEdm1G7RnhaHQoLEZFqaN2WfX0d+cxcvZk9hSXUSIjjuA5NQuGRTKekulU2d5XCQkSkmttTWMzsNVv3d5Svzi9tdbRpXDvUSZ7M4E5Nj2hhJ4WFiEiUydm6m6kr85kWanXs3ltMjfg4Tj+mOQ9f1O8HvaY+DSUiEmXaNqnDpQPbcenAdhQUFZO19mumrthEjYTwL+CksBARqYZqJsQzpHMzhnRudlS+XljjyMyGm9kKM1tlZreWc7ymmb0UOv6ZmbUP7W9vZt+a2YLQ47Fw1ikiIocWtpaFmcUDE4HTgFxgjplluvvSMqddAXzt7p3NbAzwZ+DC0LHV7t4nXPWJiEjlhbNlkQ6scvdsd98LTAFGHXDOKODZ0PNXgVMtVtYyFBGpRsIZFq2BnDLbuaF95Z7j7kXANqBp6FgHM5tvZtPM7IQw1ikiIhUIZwd3eS2EAz+ne7BzNgAp7r7FzPoD/zSzY9x9+3cuNhsPjAdISUmpgpJFRKQ84WxZ5AJty2y3AfIO
do6ZJQANga3uXuDuWwDcfS6wGuhy4Bdw90nunubuaUlJSWH4FkREBMIbFnOAVDPrYGY1gDFA5gHnZALjQs/PAz50dzezpFAHOWbWEUgFssNYq4iIHELYbkO5e5GZTQDeBeKBp919iZndBWS5eybwFPCcma0CtlIaKAAnAneZWRFQDFzt7lvDVauIiBxa1Ez3YWb5wLojeIlmwOYqKqe603vxXXo/vkvvx/9Ew3vRzt0rvI8fNWFxpMwsqzLzo8QCvRffpffju/R+/E8svRfhn1BERESqPYWFiIhUSGHxP5OCLiCC6L34Lr0f36X3439i5r1Qn4WIiFRILQsREalQzIdFRdOoxxIza2tmH5nZMjNbYmY3Bl1T0MwsPjRH2VtB1xI0M2tkZq+a2fLQ/yODgq4pSGZ2U+jn5HMze9HMagVdUzjFdFiUmUZ9BNADGGtmPYKtKlBFwM/dvTswELguxt8PgBuBZUEXESEeBN5x925Ab2L4fTGz1sANQJq796R04PGYQ19VvcV0WFC5adRjhrtvcPd5oec7KP1lcOBMwTHDzNoAZwFPBl1L0MysAaUzKzwF4O573f2bYKsKXAJQOzSvXR2+P/ddVIn1sKjMNOoxKbRqYV/gs2ArCdQDwC+BkqALiQAdgXzg76Hbck+aWd2giwqKu38J/B+wntJZsre5+3+DrSq8Yj0sKjONeswxs3rAa8DPDpwWPlaY2dnAptCsx1L6V3Q/4FF37wvsAmK2j8/MGlN6F6ID0Aqoa2aXBFtVeMV6WFRmGvWYYmaJlAbF8+7+etD1BGgIkGFmaym9PXmKmf0j2JIClQvkuvu+luarlIZHrBoGrHH3fHcvBF4HBgdcU1jFelhUZhr1mBFa0vYpYJm73x90PUFy99vcvY27t6f0/4sP3T2q/3I8FHffCOSYWdfQrlOBpQGWFLT1wEAzqxP6uTmVKO/wD+dKeRHvYNOoB1xWkIYAlwKLzWxBaN+v3f3tAGuSyHE98HzoD6ts4McB1xMYd//MzF4F5lH6KcL5RPlobo3gFhGRCsX6bSgREakEhYWIiFRIYSEiIhVSWIiISIUUFiIiUiGFhchhMLNiM1tQ5lFlo5jNrL2ZfV5VrydSlWJ6nIXID/Ctu/cJugiRo00tC5EqYGZrzezPZjY79Ogc2t/OzD4ws0Whf1NC+5ub2RtmtjD02DdVRLyZPRFaJ+G/ZlY7sG9KpAyFhcjhqX3AbagLyxzb7u7pwMOUzlhL6Plkd+8FPA/8LbT/b8A0d+9N6RxL+2YOSAUmuvsxwDfAuWH+fkQqRSO4RQ6Dme1093rl7F8LnOLu2aHJGDe6e1Mz2wy0dPfC0P4N7t7MzPKBNu5eUOY12gPvuXtqaPtXQKK7/yH835nIoallIVJ1/CDPD3ZOeQrKPC9G/YoSIRQWIlXnwjL/zgo9n8n/ltu8GJgeev4BcA3sX+e7wdEqUuSH0F8tIoendpkZeaF0Tep9H5+taWafUfpH2NjQvhuAp83sF5SuNLdvptYbgUlmdgWlLYhrKF1xTSQiqc9CpAqE+izS3H1z0LWIhINuQ4mISIXUshARkQqpZSEiIhVSWIiISIUUFiIiUiGFhYiIVEhhISIiFVJYiIhIhf4/SfYvlkTM4lUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#Plot the loss\n",
    "\n",
    "plt.plot(rnn.loss_list)\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Cost\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#full_save(rnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm_leaky_relu_5l_[50,40,30,20,10]c/rnn_basic\n"
     ]
    }
   ],
   "source": [
    "folder = get_rnn_folder(ncells = [50, 40, 30, 20, 10], cell_type = \"lstm\", activation = \"leaky_relu\")\n",
    "rnn = full_load(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###test_input.shape###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Here I predict based on my test set\n",
    "\n",
    "test_pred = rnn.predict(test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Here I subtract the target from a prediction (random particle) to get an idea of the prediction errors\n",
    "\n",
    "#scaler_inv(test_input, scalerfunc = func)[0,:,:]\n",
    "\n",
    "\n",
    "diff = scaler_inv(test_pred, scalerfunc = func)-scaler_inv(test_target, scalerfunc = func )\n",
    "\n",
    "print(diff[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Here I evaluate my model on the test set based on mean_squared_error\n",
    "\n",
    "print(\"Loss on test set:\", rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}