{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "from sklearn import preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import data as array\n",
    "# 8 hits with x,y,z\n",
    "\n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')\n",
    "#print(testset)"
   ]
  },
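  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Quick sanity check on the loaded data (illustrative sketch: it assumes one\n",
    "#row per track with 24 columns, i.e. x,y,z for each of the 8 hits)\n",
    "\n",
    "assert testset.shape[1] == 24, \"expected 8 hits * 3 coordinates\"\n",
    "print(testset.shape)"
   ]
  },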
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Check testset with arbitrary particle\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "#print(tset.shape)\n",
    "#for i in range(8):\n",
    "    #print(tset[1,3*i:(3*i+3)])\n",
    "#print(tset[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    timesteps = int(arr_orig.shape[1]/3)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    arr = np.zeros((number_examples, timesteps, 3))\n",
    "    \n",
    "    for i in range(number_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,t,0:3] = arr_orig[i,3*t:3*t+3]\n",
    "        \n",
    "    return arr\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    timesteps = int(array_shaped.shape[1])\n",
    "    num_examples = int(array_shaped.shape[0])\n",
    "    arr = np.zeros((num_examples, timesteps*3))\n",
    "    \n",
    "    for i in range(num_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,3*t:3*t+3] = array_shaped[i,t,:]\n",
    "        \n",
    "    return arr"
   ]
  },
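  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Optional: the loops above can be replaced by a single numpy reshape, since\n",
    "#the flat layout is x1,y1,z1,...,x8,y8,z8 (sketch, equivalent up to dtype):\n",
    "\n",
    "def reshapor_fast(arr_orig):\n",
    "    return arr_orig.reshape(arr_orig.shape[0], -1, 3)\n",
    "\n",
    "def reshapor_inv_fast(array_shaped):\n",
    "    return array_shaped.reshape(array_shaped.shape[0], -1)"
   ]
  },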
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "    \n",
    "    #evaluate siye of training and test set and initialize them\n",
    "    train_set_size = np.int(num_examples*train_to_total_ratio)\n",
    "    test_set_size = num_examples - train_set_size\n",
    "    \n",
    "    train_set = np.zeros((train_set_size, dataset.shape[1]))\n",
    "    test_set = np.zeros((test_set_size, dataset.shape[1]))\n",
    "   \n",
    "\n",
    "    #fill train and test sets\n",
    "    for i in range(num_examples):\n",
    "        if train_set_size > i:\n",
    "            train_set[i,:] += dataset[i,:]\n",
    "        else:\n",
    "            test_set[i - train_set_size,:]  += dataset[i,:]\n",
    "                \n",
    "    return train_set, test_set\n",
    "        "
   ]
  },
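  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#After the permutation the fill loop above amounts to plain slicing;\n",
    "#a minimal sketch of the same split:\n",
    "\n",
    "def create_random_sets_fast(dataset, train_to_total_ratio):\n",
    "    p = np.random.permutation(dataset.shape[0])\n",
    "    dataset = dataset[p,:]\n",
    "    train_set_size = int(dataset.shape[0]*train_to_total_ratio)\n",
    "    return dataset[:train_set_size,:], dataset[train_set_size:,:]"
   ]
  },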
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "\n",
    "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "#print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn\n",
    "\n",
    "#set the transormation based on training set\n",
    "def set_min_max_scaler(arr, feature_range= (-1,1)):\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    if len(arr.shape) == 3:\n",
    "        arr = reshapor(min_max_scalor.fit_transform(reshapor_inv(arr)))        \n",
    "    else:\n",
    "        arr = min_max_scalor.fit_transform(arr)\n",
    "    return min_max_scalor\n",
    "\n",
    "min_max_scalor = set_min_max_scaler(train_set)\n",
    "\n",
    "\n",
    "#transform data\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "    \n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(min_max_scalor.transform(reshapor_inv(arr)))\n",
    "        else:            \n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "            \n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = min_max_scalor.transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "    \n",
    "    return arr\n",
    "        \n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "    \n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(min_max_scalor.inverse_transform(reshapor_inv(arr)))\n",
    "        else:            \n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "            \n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = min_max_scalor.inverse_transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = min_max_scalor.nverse_transform(arr_)[:,:arr.shape[1]]\n",
    "    \n",
    "    return arr"
   ]
  },
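  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Round-trip sanity check (sketch): applying the inverse transform after the\n",
    "#transform should recover the original values up to floating point error.\n",
    "#At this point train_set is still flat with 24 columns.\n",
    "\n",
    "print(np.allclose(min_max_scaler_inv(min_max_scaler(train_set)), train_set))"
   ]
  },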
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn - Standard scaler\n",
    "\n",
    "#set the transormation based on training set\n",
    "def set_std_scaler(arr):\n",
    "    std_scalor = preprocessing.StandardScaler()\n",
    "    if len(arr.shape) == 3:\n",
    "        arr = reshapor(std_scalor.fit(reshapor_inv(arr)))        \n",
    "    else:\n",
    "        arr = std_scalor.fit(arr)\n",
    "    return std_scalor\n",
    "\n",
    "std_scalor = set_std_scaler(train_set)\n",
    "\n",
    "#transform data\n",
    "def std_scaler(arr, std_scalor= std_scalor):\n",
    "    \n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(std_scalor.transform(reshapor_inv(arr)))\n",
    "        else:            \n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "            \n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = std_scalor.transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.transform(arr_)[:,:arr.shape[1]]\n",
    "    \n",
    "    return arr\n",
    "        \n",
    "#inverse transformation\n",
    "def std_scaler_inv(arr, std_scalor= std_scalor):\n",
    "    \n",
    "    if len(arr.shape) == 3:\n",
    "        if arr.shape[1] == 8:\n",
    "            arr = reshapor(std_scalor.inverse_transform(reshapor_inv(arr)))\n",
    "        else:            \n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr = reshapor_inv(arr)\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "            arr = reshapor(arr)\n",
    "            \n",
    "    else:\n",
    "        if arr.shape[1] == 24:\n",
    "            arr = std_scalor.inverse_transform(arr)\n",
    "        else:\n",
    "            arr_ = np.zeros((arr.shape[0],24))\n",
    "            arr_[:,:arr.shape[1]] += arr\n",
    "            arr = std_scalor.inverse_transform(arr_)[:,:arr.shape[1]]\n",
    "    \n",
    "    return arr\n",
    "\n"
   ]
  },
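  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#The zero-padding branch above lets the scalers also handle the 7-timestep\n",
    "#input/target arrays used later. Minimal check on dummy data (illustrative):\n",
    "\n",
    "_dummy = np.random.rand(5, 7, 3)\n",
    "print(std_scaler(_dummy).shape) #expected: (5, 7, 3)"
   ]
  },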
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "#reshape the data\n",
    "\n",
    "train_set = reshapor(train_set)\n",
    "test_set = reshapor(test_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Scale data either with MinMax scaler or with Standard scaler\n",
    "#Return scalor if fit = True and and scaled array otherwise\n",
    "\n",
    "def scaler(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \n",
    "    if scalerfunc == \"std\":\n",
    "        arr = std_scaler(arr, std_scalor= std_scalor)\n",
    "        return arr\n",
    "    \n",
    "    elif scalerfunc == \"minmax\":\n",
    "        arr = min_max_scaler(arr, min_max_scalor= min_max_scalor)\n",
    "        return arr\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n",
    "\n",
    "def scaler_inv(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "\n",
    "    if scalerfunc == \"std\":\n",
    "        arr = std_scaler_inv(arr, std_scalor= std_scalor)\n",
    "        return arr\n",
    "    \n",
    "    elif scalerfunc == \"minmax\":\n",
    "        arr = min_max_scaler_inv(arr, min_max_scalor= std_scalor)\n",
    "        return arr\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.02109399  0.0394468  -0.01875739]\n",
      " [-0.0158357   0.02916325 -0.02021501]\n",
      " [-0.00411211  0.01346626 -0.01817778]\n",
      " [-0.00314466  0.01169437 -0.00971874]\n",
      " [ 0.00827457 -0.00905463 -0.00903793]\n",
      " [ 0.00906477 -0.01100179 -0.00610165]\n",
      " [ 0.01623521 -0.02745446  0.00036546]\n",
      " [ 0.01879028 -0.03098714 -0.0009012 ]]\n"
     ]
    }
   ],
   "source": [
    "#scale the data\n",
    "\n",
    "func = \"minmax\"\n",
    "\n",
    "train_set = scaler(train_set, scalerfunc = func)\n",
    "test_set = scaler(test_set, scalerfunc = func)\n",
    "\n",
    "print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    \n",
    "    num_examples = inputt.shape[0]\n",
    "    \n",
    "    \n",
    "    #Number of complete batches\n",
    "    \n",
    "    number_of_batches = int(num_examples/minibatch_size)\n",
    "    minibatches = []\n",
    "   \n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    #print(_t.shape)\n",
    "        \n",
    "    \n",
    "    for i in range(number_of_batches):\n",
    "        \n",
    "        minibatch_train = _i[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatch_true = _t[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatches.append((minibatch_train, minibatch_true))\n",
    "        \n",
    "        \n",
    "    minibatches.append((_i[number_of_batches*minibatch_size:, :, :], _t[number_of_batches*minibatch_size:, :, :]))\n",
    "    \n",
    "    \n",
    "    return minibatches\n",
    "        "
   ]
  },
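  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Usage sketch: the returned list holds floor(N/minibatch_size) full batches\n",
    "#plus one (possibly empty) remainder batch, so shapes can be checked like this:\n",
    "\n",
    "_mb = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:])\n",
    "print(len(_mb), _mb[0][0].shape, _mb[0][1].shape)"
   ]
  },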
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create random minibatches of train and test set with input and target array\n",
    "\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
    "#_train, _target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "#print(train[0,:,:], target[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \n",
    "    \n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\"):\n",
    "        \n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "        \n",
    "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
    "        #### ninputs is the dimentionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "\n",
    "        \n",
    "        #Check if activation function valid and set activation\n",
    "        if activation==\"relu\":\n",
    "            self.activation = tf.nn.relu\n",
    "            \n",
    "        elif activation==\"tanh\":\n",
    "            self.activation = tf.nn.tanh\n",
    "                    \n",
    "        elif activation==\"leaky_relu\":\n",
    "            self.activation = tf.nn.leaky_relu\n",
    "            \n",
    "        elif activation==\"elu\":\n",
    "            self.activation = tf.nn.elu\n",
    "            \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(activation))\n",
    "        \n",
    "        \n",
    "        \n",
    "        #Check if cell type valid and set cell_type\n",
    "        if cell_type==\"basic_rnn\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicRNNCell\n",
    "            \n",
    "        elif cell_type==\"lstm\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicLSTMCell\n",
    "                    \n",
    "        elif cell_type==\"GRU\":\n",
    "            self.cell_type = tf.contrib.rnn.GRUCell\n",
    "            \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(cell_type))\n",
    "            \n",
    "        \n",
    "        #Check Input of ncells        \n",
    "        if (type(self.ncells) == int):\n",
    "            self.ncells = [self.ncells]\n",
    "        \n",
    "        if (type(self.ncells) != list):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        \n",
    "        for _ in range(len(self.ncells)):\n",
    "            if type(self.ncells[_]) != int:\n",
    "                raise ValueError(\"Wrong type of Input for ncells\")\n",
    "                \n",
    "        self.activationlist = []\n",
    "        for _ in range(len(self.ncells)-1):\n",
    "            self.activationlist.append(self.activation)\n",
    "        self.activationlist.append(tf.nn.tanh)\n",
    "        \n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n",
    "                                                 for layer in range(len(self.ncells))])\n",
    "            \n",
    "        \n",
    "        #### I now define the output\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        self.sess = tf.Session()\n",
    "        \n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the mean_squared_error (distance of predicted point to target)\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))   \n",
    "        \n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "  \n",
    "    \n",
    "    def save(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        self.saver.save(self.sess, filename)\n",
    "            \n",
    "            \n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        self.saver.restore(self.sess, filename)\n",
    "        \n",
    "        \n",
    "        \n",
    "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
    "        self.loss_list = []\n",
    "        patience_cnt = 0\n",
    "        epoche_save = 0\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n",
    "        \n",
    "        for iep in range(epochs):\n",
    "            loss = 0\n",
    "            \n",
    "            batches = len(minibatches)\n",
    "            #Here I iterate over the batches\n",
    "            for batch in range(batches):\n",
    "            #### Here I train the RNNcell\n",
    "            #### The X is the time series, the Y is shifted by 1 time step\n",
    "                train, target = minibatches[batch]\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "                \n",
    "            \n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "            \n",
    "            #Normalize loss over number of batches and scale it back before normaliziation\n",
    "            loss /= batches\n",
    "            self.loss_list.append(loss)\n",
    "            \n",
    "            #print(loss)\n",
    "            \n",
    "            #Here I create the checkpoint if the perfomance is better\n",
    "            if iep > 1 and iep%checkpoint == 0 and self.loss_list[iep] < self.loss_list[epoche_save]:\n",
    "                #print(\"Checkpoint created at epoch: \", iep)\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "            \n",
    "            #early stopping with patience\n",
    "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 2/10**7:\n",
    "                patience_cnt += 1\n",
    "                #print(\"Patience now at: \", patience_cnt, \" of \", patience)\n",
    "                \n",
    "                if patience_cnt + 1 > patience:\n",
    "                    print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
    "                    print(\"Cost: \",loss)\n",
    "                    break\n",
    "            \n",
    "            #Note that the loss here is multiplied with 1000 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*10**6, \"e-6\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "        \n",
    "        #Set model back to the last checkpoint if performance was better\n",
    "        if self.loss_list[epoche_save] < self.loss_list[iep]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\", \"Last checkpoint at epoch \", epoche_save, \" loaded\")\n",
    "            print(\"Performance at last checkpoint is \" ,self.loss_list[iep] - self.loss_list[epoche_save], \" better\" )\n",
    "            \n",
    "            \n",
    "        \n",
    "    def predict(self, x):\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [50, 40, 30, 20, 10]\n",
    "\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "rnn.set_cost_and_functions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch number  0\n",
      "Cost:  10.041199672838395 e-3\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  0 \n",
      "\n",
      "Epoch number  5\n",
      "Cost:  0.14646259134021053 e-3\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  5 \n",
      "\n",
      "Epoch number  10\n",
      "Cost:  0.14038292159811852 e-3\n",
      "Patience:  5 / 200\n",
      "Last checkpoint at: Epoch  10 \n",
      "\n",
      "Epoch number  15\n",
      "Cost:  0.13558934176429868 e-3\n",
      "Patience:  10 / 200\n",
      "Last checkpoint at: Epoch  15 \n",
      "\n",
      "Epoch number  20\n",
      "Cost:  0.12642440127278182 e-3\n",
      "Patience:  14 / 200\n",
      "Last checkpoint at: Epoch  20 \n",
      "\n",
      "Epoch number  25\n",
      "Cost:  0.1116786241912818 e-3\n",
      "Patience:  16 / 200\n",
      "Last checkpoint at: Epoch  25 \n",
      "\n",
      "Epoch number  30\n",
      "Cost:  0.10637743763893129 e-3\n",
      "Patience:  20 / 200\n",
      "Last checkpoint at: Epoch  30 \n",
      "\n",
      "Epoch number  35\n",
      "Cost:  0.10180761176904544 e-3\n",
      "Patience:  21 / 200\n",
      "Last checkpoint at: Epoch  35 \n",
      "\n",
      "Epoch number  40\n",
      "Cost:  0.10329305703325713 e-3\n",
      "Patience:  25 / 200\n",
      "Last checkpoint at: Epoch  35 \n",
      "\n",
      "Epoch number  45\n",
      "Cost:  0.09893714772299567 e-3\n",
      "Patience:  26 / 200\n",
      "Last checkpoint at: Epoch  45 \n",
      "\n",
      "Epoch number  50\n",
      "Cost:  0.09669851916693548 e-3\n",
      "Patience:  28 / 200\n",
      "Last checkpoint at: Epoch  50 \n",
      "\n",
      "Epoch number  55\n",
      "Cost:  0.09474931256919901 e-3\n",
      "Patience:  30 / 200\n",
      "Last checkpoint at: Epoch  55 \n",
      "\n",
      "Epoch number  60\n",
      "Cost:  0.09272654031210163 e-3\n",
      "Patience:  33 / 200\n",
      "Last checkpoint at: Epoch  60 \n",
      "\n",
      "Epoch number  65\n",
      "Cost:  0.09420149952812279 e-3\n",
      "Patience:  35 / 200\n",
      "Last checkpoint at: Epoch  60 \n",
      "\n",
      "Epoch number  70\n",
      "Cost:  0.09541216964630331 e-3\n",
      "Patience:  36 / 200\n",
      "Last checkpoint at: Epoch  60 \n",
      "\n",
      "Epoch number  75\n",
      "Cost:  0.09047800716522962 e-3\n",
      "Patience:  39 / 200\n",
      "Last checkpoint at: Epoch  75 \n",
      "\n",
      "Epoch number  80\n",
      "Cost:  0.09089725666699257 e-3\n",
      "Patience:  39 / 200\n",
      "Last checkpoint at: Epoch  75 \n",
      "\n",
      "Epoch number  85\n",
      "Cost:  0.08590354093726962 e-3\n",
      "Patience:  40 / 200\n",
      "Last checkpoint at: Epoch  85 \n",
      "\n",
      "Epoch number  90\n",
      "Cost:  0.08550771444595041 e-3\n",
      "Patience:  41 / 200\n",
      "Last checkpoint at: Epoch  90 \n",
      "\n",
      "Epoch number  95\n",
      "Cost:  0.08262849370750816 e-3\n",
      "Patience:  42 / 200\n",
      "Last checkpoint at: Epoch  95 \n",
      "\n",
      "Epoch number  100\n",
      "Cost:  0.08081882078066825 e-3\n",
      "Patience:  45 / 200\n",
      "Last checkpoint at: Epoch  100 \n",
      "\n",
      "Epoch number  105\n",
      "Cost:  0.08332692371542624 e-3\n",
      "Patience:  48 / 200\n",
      "Last checkpoint at: Epoch  100 \n",
      "\n",
      "Epoch number  110\n",
      "Cost:  0.0850605532871262 e-3\n",
      "Patience:  50 / 200\n",
      "Last checkpoint at: Epoch  100 \n",
      "\n",
      "Epoch number  115\n",
      "Cost:  0.08140491588571248 e-3\n",
      "Patience:  50 / 200\n",
      "Last checkpoint at: Epoch  100 \n",
      "\n",
      "Epoch number  120\n",
      "Cost:  0.0823190781916987 e-3\n",
      "Patience:  52 / 200\n",
      "Last checkpoint at: Epoch  100 \n",
      "\n",
      "Epoch number  125\n",
      "Cost:  0.0766505290038309 e-3\n",
      "Patience:  55 / 200\n",
      "Last checkpoint at: Epoch  125 \n",
      "\n",
      "Epoch number  130\n",
      "Cost:  0.07502320984210027 e-3\n",
      "Patience:  56 / 200\n",
      "Last checkpoint at: Epoch  130 \n",
      "\n",
      "Epoch number  135\n",
      "Cost:  0.0758755330102855 e-3\n",
      "Patience:  57 / 200\n",
      "Last checkpoint at: Epoch  130 \n",
      "\n",
      "Epoch number  140\n",
      "Cost:  0.0731801113207884 e-3\n",
      "Patience:  58 / 200\n",
      "Last checkpoint at: Epoch  140 \n",
      "\n",
      "Epoch number  145\n",
      "Cost:  0.0745931863499944 e-3\n",
      "Patience:  60 / 200\n",
      "Last checkpoint at: Epoch  140 \n",
      "\n",
      "Epoch number  150\n",
      "Cost:  0.05597170093096793 e-3\n",
      "Patience:  60 / 200\n",
      "Last checkpoint at: Epoch  150 \n",
      "\n",
      "Epoch number  155\n",
      "Cost:  0.0448569248584 e-3\n",
      "Patience:  61 / 200\n",
      "Last checkpoint at: Epoch  155 \n",
      "\n",
      "Epoch number  160\n",
      "Cost:  0.0377340710404864 e-3\n",
      "Patience:  63 / 200\n",
      "Last checkpoint at: Epoch  160 \n",
      "\n",
      "Epoch number  165\n",
      "Cost:  0.03712705128759324 e-3\n",
      "Patience:  64 / 200\n",
      "Last checkpoint at: Epoch  165 \n",
      "\n",
      "Epoch number  170\n",
      "Cost:  0.037240219558527236 e-3\n",
      "Patience:  67 / 200\n",
      "Last checkpoint at: Epoch  165 \n",
      "\n",
      "Epoch number  175\n",
      "Cost:  0.041023939860330774 e-3\n",
      "Patience:  67 / 200\n",
      "Last checkpoint at: Epoch  165 \n",
      "\n",
      "Epoch number  180\n",
      "Cost:  0.03179026030108056 e-3\n",
      "Patience:  69 / 200\n",
      "Last checkpoint at: Epoch  180 \n",
      "\n",
      "Epoch number  185\n",
      "Cost:  0.037844479401370486 e-3\n",
      "Patience:  71 / 200\n",
      "Last checkpoint at: Epoch  180 \n",
      "\n",
      "Epoch number  190\n",
      "Cost:  0.02333719505181665 e-3\n",
      "Patience:  72 / 200\n",
      "Last checkpoint at: Epoch  190 \n",
      "\n",
      "Epoch number  195\n",
      "Cost:  0.02318771433412157 e-3\n",
      "Patience:  77 / 200\n",
      "Last checkpoint at: Epoch  195 \n",
      "\n",
      "Epoch number  200\n",
      "Cost:  0.025808127712151234 e-3\n",
      "Patience:  79 / 200\n",
      "Last checkpoint at: Epoch  195 \n",
      "\n",
      "Epoch number  205\n",
      "Cost:  0.021487966265301518 e-3\n",
      "Patience:  82 / 200\n",
      "Last checkpoint at: Epoch  205 \n",
      "\n",
      "Epoch number  210\n",
      "Cost:  0.020788879447401144 e-3\n",
      "Patience:  85 / 200\n",
      "Last checkpoint at: Epoch  210 \n",
      "\n",
      "Epoch number  215\n",
      "Cost:  0.02056433168810203 e-3\n",
      "Patience:  85 / 200\n",
      "Last checkpoint at: Epoch  215 \n",
      "\n",
      "Epoch number  220\n",
      "Cost:  0.016506806942027438 e-3\n",
      "Patience:  89 / 200\n",
      "Last checkpoint at: Epoch  220 \n",
      "\n",
      "Epoch number  225\n",
      "Cost:  0.020985714496767265 e-3\n",
      "Patience:  91 / 200\n",
      "Last checkpoint at: Epoch  220 \n",
      "\n",
      "Epoch number  230\n",
      "Cost:  0.011625469693520225 e-3\n",
      "Patience:  94 / 200\n",
      "Last checkpoint at: Epoch  230 \n",
      "\n",
      "Epoch number  235\n",
      "Cost:  0.013143771576188614 e-3\n",
      "Patience:  98 / 200\n",
      "Last checkpoint at: Epoch  230 \n",
      "\n",
      "Epoch number  240\n",
      "Cost:  0.017444268354317522 e-3\n",
      "Patience:  100 / 200\n",
      "Last checkpoint at: Epoch  230 \n",
      "\n",
      "Epoch number  245\n",
      "Cost:  0.013935790078942367 e-3\n",
      "Patience:  101 / 200\n",
      "Last checkpoint at: Epoch  230 \n",
      "\n",
      "Epoch number  250\n",
      "Cost:  0.01056458899875771 e-3\n",
      "Patience:  103 / 200\n",
      "Last checkpoint at: Epoch  250 \n",
      "\n",
      "Epoch number  255\n",
      "Cost:  0.013950063650090088 e-3\n",
      "Patience:  106 / 200\n",
      "Last checkpoint at: Epoch  250 \n",
      "\n",
      "Epoch number  260\n",
      "Cost:  0.015239623812800694 e-3\n",
      "Patience:  109 / 200\n",
      "Last checkpoint at: Epoch  250 \n",
      "\n",
      "Epoch number  265\n",
      "Cost:  0.014050647958820845 e-3\n",
      "Patience:  112 / 200\n",
      "Last checkpoint at: Epoch  250 \n",
      "\n",
      "Epoch number  270\n",
      "Cost:  0.009441311336326799 e-3\n",
      "Patience:  112 / 200\n",
      "Last checkpoint at: Epoch  270 \n",
      "\n",
      "Epoch number  275\n",
      "Cost:  0.00812686008391617 e-3\n",
      "Patience:  116 / 200\n",
      "Last checkpoint at: Epoch  275 \n",
      "\n",
      "Epoch number  280\n",
      "Cost:  0.009064912048531968 e-3\n",
      "Patience:  118 / 200\n",
      "Last checkpoint at: Epoch  275 \n",
      "\n",
      "Epoch number  285\n",
      "Cost:  0.007350245905786808 e-3\n",
      "Patience:  119 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  290\n",
      "Cost:  0.009190695427025004 e-3\n",
      "Patience:  123 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  295\n",
      "Cost:  0.009242598896386706 e-3\n",
      "Patience:  126 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  300\n",
      "Cost:  0.009243554339921871 e-3\n",
      "Patience:  131 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  305\n",
      "Cost:  0.008543941756680069 e-3\n",
      "Patience:  134 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  310\n",
      "Cost:  0.008661668753700995 e-3\n",
      "Patience:  137 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  315\n",
      "Cost:  0.008509848796282003 e-3\n",
      "Patience:  142 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  320\n",
      "Cost:  0.009688999833745953 e-3\n",
      "Patience:  145 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  325\n",
      "Cost:  0.010096690673774302 e-3\n",
      "Patience:  148 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  330\n",
      "Cost:  0.008155997478597589 e-3\n",
      "Patience:  152 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  335\n",
      "Cost:  0.012822152828138837 e-3\n",
      "Patience:  156 / 200\n",
      "Last checkpoint at: Epoch  285 \n",
      "\n",
      "Epoch number  340\n",
      "Cost:  0.00638995292552244 e-3\n",
      "Patience:  159 / 200\n",
      "Last checkpoint at: Epoch  340 \n",
      "\n",
      "Epoch number  345\n",
      "Cost:  0.0066921474113924165 e-3\n",
      "Patience:  164 / 200\n",
      "Last checkpoint at: Epoch  340 \n",
      "\n",
      "Epoch number  350\n",
      "Cost:  0.006151222709028862 e-3\n",
      "Patience:  169 / 200\n",
      "Last checkpoint at: Epoch  350 \n",
      "\n",
      "Epoch number  355\n",
      "Cost:  0.006081407573606641 e-3\n",
      "Patience:  170 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "Epoch number  360\n",
      "Cost:  0.007673800716494293 e-3\n",
      "Patience:  175 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "Epoch number  365\n",
      "Cost:  0.0072596388893911585 e-3\n",
      "Patience:  180 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "Epoch number  370\n",
      "Cost:  0.006717292966099427 e-3\n",
      "Patience:  184 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "Epoch number  375\n",
      "Cost:  0.006316999443175093 e-3\n",
      "Patience:  189 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "Epoch number  380\n",
      "Cost:  0.006750347554461382 e-3\n",
      "Patience:  193 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "Epoch number  385\n",
      "Cost:  0.006520240363665544 e-3\n",
      "Patience:  198 / 200\n",
      "Last checkpoint at: Epoch  355 \n",
      "\n",
      "\n",
      " Early stopping at epoch  387 , difference:  9.317458766708246e-07\n",
      "Cost:  5.49251195054162e-06\n"
     ]
    }
   ],
   "source": [
    "rnn.fit(minibatches, epochs = 5000, print_step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAEKCAYAAAA4t9PUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHq5JREFUeJzt3X2QXfV93/H35967u1o9IIxQXJAwkoPSRHj8uFH9VE9rUls4nshp5SDGTpiUKa0LtZ00cVA9JS5TpqVpjc0YO0MMMSaOBZXt8Y5LjB0LJ+PWkVhsbCOI7C3gsgabxSAhkPbh3vvtH+e3q6vLvecsu3t2L+LzmtnZc3/3nHO/9+zu/ezvPPyOIgIzM7P5qix3AWZm9sLmIDEzswVxkJiZ2YI4SMzMbEEcJGZmtiAOEjMzWxAHiZmZLYiDxMzMFsRBYmZmC1Jb7gKWwplnnhmbNm1a7jLMzF4w7rnnniciYv1c5n1RBMmmTZsYGRlZ7jLMzF4wJP14rvN615aZmS2Ig8TMzBbEQWJmZgviIDEzswVxkJiZ2YKUGiSStks6JGlU0pUdnh+QdFt6fr+kTal9naS7JD0j6RNty7xO0g/SMtdLUpnvwczM8pUWJJKqwA3AhcBW4GJJW9tmuxR4KiLOA64Drk3tE8B/BP6gw6o/BVwGbElf2xe/ejMzm6syeyTbgNGIeDAipoA9wI62eXYAt6TpvcAFkhQRz0bEt8gCZZaks4DTIuLbkd0j+LPAu8p6A9d/40f8zQ/Hy1q9mdkpocwg2QA80vJ4LLV1nCci6sARYF3BOscK1rloPvnNUf736BNlrd7M7JRQZpB0OnYR85hnXvNLukzSiKSR8fH59SqEyDo+ZmbWTZlBMgac0/J4I/Bot3kk1YC1wJMF69xYsE4AIuLGiBiKiKH16+c0XMxzSOAcMTPLV2aQ3A1skbRZUj+wCxhum2cYuCRN7wT2RU4XICIeA45Ken06W+t3gC8vfukZkd89MjOzEgdtjIi6pCuAO4EqcHNEHJR0NTASEcPATcCtkkbJeiK7ZpaX9DBwGtAv6V3A2yLifuB9wGeAQeCv0lcpJLlHYmZWoNTRfyPiDuCOtrarWqYngHd3WXZTl/YR4BWLV2V3WY/ESWJmlsdXtufxMRIzs0IOkhy+ZN7MrJiDJEd2jMRdEjOzPA6SHJLP2jIzK+IgySF8jMTMrIiDJIckn7VlZlbAQZKjImg6R8zMcjlIcvmCRDOzIg6SHNkts5wkZmZ5HCQ5fLDdzKyYgySHR/81MyvmIMkhfNaWmVkRB0kO90jMzIo5SHL4fiRmZsUcJDl8PxIzs2IOkgI+RmJmls9BkkPet2VmVshBksOj/5qZFXOQ5BC+H4mZWREHSQ73SMzMijlIcniIFDOzYg6SHNn9SMzMLI+DJEfWI3GUmJnlcZDk8TESM7NCDpIcvh2JmVkxB0kO37PdzKyYgySHz9oyMyvmIMnhYeTNzIo5SHL4xlZmZsUcJDncIzEzK+YgKeAcMTPLV2qQSNou6ZCkUUlXdnh+QNJt6fn9kja1PLc7tR+S9PaW9t+TdFDSfZI+L2lFifW7R2JmVqC0IJFUBW4ALgS2AhdL2to226XAUxFxHnAdcG1adiuwCzgf2A58UlJV0gbg/cBQRLwCqKb5ynkP+Mp2M7MiZfZItgGjEfFgREwBe4AdbfPsAG5J03uBCyQpte+JiMmIeAgYTesDqAGDkmrASuDRst5ApeJdW2ZmRcoMkg3AIy2Px1Jbx3kiog4cAdZ1WzYifgL8d+D/AY8BRyLia51eXNJlkkYkjYyPj8/rDfh+JGZmxcoMEnVoa/9U7jZPx3ZJLyHrrWwGzgZWSXpvpxePiBsjYigihtavX/88ym4pzmNtmZkVKjNIxoBzWh5v5Lm7oWbnSbuq1gJP5iz7a8BDETEeEdPAF4E3llI9vrLdzGwuygySu4EtkjZL6ic7KD7cNs8wcEma3gnsi2xf0jCwK53VtRnYAhwg26X1ekkr07GUC4AHSnsHvh+JmVmhWlkrjoi6pCuAO8nOrro5Ig5KuhoYiYhh4CbgVkmjZD2RXWnZg5JuB+4H6sDlEdEA9kvaC3wntX8XuLGs9+CztszMipUWJAARcQdwR1vbVS3TE8C7uyx7DXBNh/Y/Bv54cSvtTJ2O1JiZ2Ul8ZXsOHyMxMyvmIMnh+5GYmRVzkORwj8TMrJiDJIdH/zUzK+YgyeH7kZiZFXOQ5HGPxMyskIMkh/AQKWZmRRwkOeQkMTMr5CDJ4WMkZmbFHCQ5fNaWmVkxB0kODyNvZlbMQZLDN7YyMyvmIMnhHomZWTEHSQF3SMzM8jlIcsg3tjIzK+QgySFwl8TMrICDJIePkZiZFXOQ5BDQdI/EzCyXgySHJO/ZMjMr4CDJUfGV7WZmhRwkuXzWlplZEQdJjmysLUeJmVkeB0kOLXcBZmYvAA6SHB7918ysmIMkh+9HYmZWzEGSwz0SM7NiDpIcvrLdzKyYgySH70diZlbMQZLHPRIzs0IOkhzZ6L/LXYWZWW8rNUgkbZd0SNKopCs7PD8g6bb0/H5Jm1qe253aD0l6e0v76ZL2Svp7SQ9IekOJ9TtHzMwKlBYkkqrADcCFwFbgYklb22a7FHgqIs4DrgOuTctuBXYB5wPbgU+m9QF8HPhqRPwy8CrggdLeA76y3cysSJk9km3AaEQ8GBFTwB5gR9s8O4Bb0vRe4AJJSu17ImIyIh4CRoFtkk4D3gLcBBARUxFxuKw34LO2zMyKlRkkG4BHWh6PpbaO80REHTgCrMtZ9uXAOPDnkr4r6dOSVpVT/kyPpKy1m5mdGsoMkk5DVbV/LHebp1t7DXgt8KmIeA3wLPCcYy8Aki6TNCJpZHx8fO5Vn7wOX9luZlagzCAZA85pebwReLTbPJJqwFrgyZxlx4CxiNif2veSBctzRMSNETEUEUPr16+f1xtwj8TMrFiZQXI3sEXSZkn9ZAfPh9vmGQYuSdM7gX2RHd0eBnals7o2A1uAAxHxU+ARSf8wLXMBcH9p78BDpJiZFaqVteKIqEu6ArgTqAI3R8RBSVcDIxExTHbQ/FZJo2Q9kV1p2YOSbicLiTpweUQ00qr/HfC5FE4PAr9b1nuQB5I3MytUWpAARMQdwB1tbVe1TE8A7+6y7DXANR3a7wWGFrfSznxjKzOzYr6yPYfw6b9mZkUcJDk8jLyZWTEHSQ7f2MrMrJiDJIcETeeImVkuB0kO79oyMys2pyCRdOtc2k49PtxuZlZkrj2S81sfpJF4X7f45fSWinskZmaFcoMk3RPkKPBKSU+nr6PA48CXl6TCZeTRf83MiuUGSUT8l4hYA/xJRJyWvtZExLqI2L1ENS4b37PdzKzYXHdtfWVmuHZJ75X0UUnnllhXT3CPxMys2FyD5FPAMUmvAj4E/Bj4bGlV9QiP/mtmVmyuQVJPo/LuAD4eER8H1pRXVm+QvGvLzKzIXAdtPCppN/DbwD9OZ231lVdW73CMmJnlm2uP5CJ
gEviX6Z4gG4A/Ka2qHiFfRmJmVmhOQZLC43PAWknvBCYi4kVwjETOETOzAnO9sv23gANk9w75LWC/pJ1lFtYLfD8SM7Nicz1G8mHgVyPicQBJ64G/Jrtn+inLe7bMzIrN9RhJZSZEkp8/j2VfsDxoo5lZsbn2SL4q6U7g8+nxRbTdQvdUJPl+JGZmRXKDRNJ5wEsj4g8l/XPgzWR7fL5NdvD9lOYLEs3MihXtnvoYcBQgIr4YEb8fEb9H1hv5WNnFLTsPkWJmVqgoSDZFxPfbGyNiBNhUSkU9RE4SM7NCRUGyIue5wcUspBdlgzY6SczM8hQFyd2S/lV7o6RLgXvKKal3+BiJmVmxorO2Pgh8SdJ7OBEcQ0A/8JtlFtYLPIy8mVmx3CCJiJ8Bb5T0T4FXpOb/FRH7Sq+sB/jGVmZmxeZ0HUlE3AXcVXItPcc9EjOzYqf81ekL4WMkZmbFHCR5JMADN5qZ5XGQ5FD67hwxM+vOQZIjdUh8nMTMLEepQSJpu6RDkkYlXdnh+QFJt6Xn90va1PLc7tR+SNLb25arSvqupK+UWj/etWVmVqS0IEn3db8BuBDYClwsaWvbbJcCT0XEecB1wLVp2a3ALuB8YDvwybS+GR8AHiir9hnukZiZFSuzR7INGI2IByNiCtgD7GibZwdwS5reC1wgSal9T0RMRsRDwGhaH5I2Ar8OfLrE2gGozASJk8TMrKsyg2QD8EjL47HU1nGeiKgDR4B1Bct+DPgQ0Mx7cUmXSRqRNDI+Pj6vN6CZs7bcJzEz66rMIFGHtvZP5G7zdGyX9E7g8YgoHOcrIm6MiKGIGFq/fn1xtbnrWtDiZmantDKDZAw4p+XxRuDRbvNIqgFrgSdzln0T8BuSHibbVfZWSX9RRvFZTWWt2czs1FFmkNwNbJG0WVI/2cHz4bZ5hoFL0vROYF9kp0gNA7vSWV2bgS3AgYjYHREbI2JTWt++iHhvWW/gxFlbZb2CmdkL31zv2f68RURd0hXAnUAVuDkiDkq6GhiJiGHgJuBWSaNkPZFdadmDkm4H7gfqwOUR0Sir1m5OnLXlJDEz66a0IAGIiDvIbsvb2nZVy/QE8O4uy14DXJOz7m8C31yMOrvxle1mZsV8ZXsOX0diZlbMQZLDV7abmRVzkORwj8TMrJiDZA7cITEz685BkkPukpiZFXKQ5Jg9a8tJYmbWlYMkhzxoo5lZIQdJjhM9EjMz68ZBkkO+Z7uZWSEHSQ4fazczK+YgyeEhUszMijlI8vjGVmZmhRwkOWZvR+IcMTPrykGSw8dIzMyKOUhyzAza2PRBEjOzrhwkOXxBoplZMQdJDl+QaGZWzEGS40SPxFFiZtaNgyTHiRtbLXMhZmY9zEGSQyqex8zsxc5BkuPEWFvLXIiZWQ9zkOTw/UjMzIo5SHL49F8zs2IOkhy+st3MrJiDJMeJs7YcJWZm3ThIcrhHYmZWzEEyB+6QmJl15yDJIXmQFDOzIg6SHL5DoplZMQdJDh8jMTMrVmqQSNou6ZCkUUlXdnh+QNJt6fn9kja1PLc7tR+S9PbUdo6kuyQ9IOmgpA+UWr/H2jIzK1RakEiqAjcAFwJbgYslbW2b7VLgqYg4D7gOuDYtuxXYBZwPbAc+mdZXB/59RPwK8Hrg8g7rXMT3kH33le1mZt2V2SPZBoxGxIMRMQXsAXa0zbMDuCVN7wUuUHaEewewJyImI+IhYBTYFhGPRcR3ACLiKPAAsKGsN+BjJGZmxcoMkg3AIy2Px3juh/7sPBFRB44A6+aybNoN9hpg/yLWfBIPkWJmVqzMIOk0CHv7R3K3eXKXlbQa+ALwwYh4uuOLS5dJGpE0Mj4+PseSn7OW9MJOEjOzbsoMkjHgnJbHG4FHu80jqQasBZ7MW1ZSH1mIfC4ivtjtxSPixogYioih9evXz+sNuEdiZlaszCC5G9giabOkfrKD58Nt8wwDl6TpncC+yAa2GgZ2pbO6NgNbgAPp+MlNwAMR8dESawc6d4vMzOxktbJWHBF1SVcAdwJV4OaIOCjpamAkIobJQuFWSaNkPZFdadmDkm4H7ic7U+vyiGhIejPw28APJN2bXuo/RMQdZbwH39jKzKxYaUECkD7g72hru6plegJ4d5dlrwGuaWv7FkvYUfCNrczMivnK9hwzx0iazhEzs64cJDlOHGx3kpiZdeMgyTE7RMoy12Fm1sscJHl8+q+ZWSEHSY4TR/WdJGZm3ThIcvj0XzOzYg6SHL4/oplZMQdJjop7JGZmhRwkOXz6r5lZMQdJDu/aMjMr5iDJ49N/zcwKOUhyyPcjMTMr5CDJIe/bMjMr5CDJ4RwxMyvmIMnhCxLNzIo5SHLMnv7rPomZWVcOkhyzu7acI2ZmXTlIcpzokZiZWTcOklwzx0gcJWZm3ThIcrhHYmZWzEGSY/Z+JE4SM7OuHCQ5Zk//dZKYmXXlIMnhs7bMzIo5SHLIgzaamRVykOQ4MWijmZl14yDJMdMjabpLYmbWlYNkDpwjZmbdOUhyyOf/mpkVcpDkmD1G4hwxM+vKQZJjpkcy1WgubyFmZj2sVubKJW0HPg5UgU9HxH9te34A+CzwOuDnwEUR8XB6bjdwKdAA3h8Rd85lnYtpw0sGecnKPq768kH23jPGulX9rOirsqKvykBfhRW1anpcOfG9duL5gVqVVQNVzljZz+kr++mvObfN7NRTWpBIqgI3AP8MGAPuljQcEfe3zHYp8FREnCdpF3AtcJGkrcAu4HzgbOCvJf1SWqZonYvmtBV93Pav38An9o3y0BPP8vDPn2ViusnkdIOJepOp+vPrqazqr3L6yn4G+ipEZINBBtBoBhPTTV6+fhUDtQrVStYVmphuMNhXZaBW5ZnJ+mxgTUw3GahVqFTE8ak6jWZw1umDiOxoTrYrLtJrQK0qVg/UOD7dYEVflXojCIJGMxjsrzJQrYBEo9mk3gxW1KpUJAb6KhybarCqv8p0o0kzoCKYagSnD/YxUW/QV6nM1lVvBvXmif2Aannvanmglmdm2mdaKhWxfvUAz0zW6a9V6K9VGKhWeGayTl+1Mrv+/lqFgRTMzWYw0FelryqOTtSpVcSaFX3Um02EWDvYx9qVfawZqFGptFZlZouhzB7JNmA0Ih4EkLQH2AG0fujvAD6SpvcCn1A2LskOYE9ETAIPSRpN62MO61xUv/TSNVx/8Ws6PtdsBpP1JhPTDSbqDSanm0zUG0xMp7bpBs9ONnjq2BRPPjvFkePTHD42zUS9QUWiouwDtCJRqYiHn3iWoxP12dONV9SqPPHMFBPTDVavqPHUsSbHphoM1CpMN5opCGoI+P7YEbJtAiCU1i3BVL3Js1NZKE1MN2aDqr+WBcVMINYqoloRU43mKXlcqCJYO9jHaYN9HJ2os23TGZx9+iDHpxv0V8WxqQYSnLtuFdWKqEocOT5NtaIU4A36qtnPqirNbsfDx6Y5c3U/Z58+yHQjaEQw2FdlZX+VqUaTnx2Z4Pyz1/KTw8f4+v2Ps/N1G1k72HdSbVL6avvZQfZPQLUiJqYbNJpBrSqmG8FTx6
Z42RkrOT7VoFLJfp+qEkq/W9XKydMVZeuuSGm+bBig6UaT6UaTlf2dPw6azaAZQa26/D3qiens91/yPwS9pMwg2QA80vJ4DPhH3eaJiLqkI8C61P53bctuSNNF61wylYoY7K8y2F9drhIWxcww+TN/nNPpmNBMj2ii3qS/WkGCeiNm//Nf0Vdlqt5kstFgYqpJX03UKtmHzUnjk3WenA2r1nmn6k3Gj05y2mAfU/UmU40mk9NNBvurNJrN2V7JdL3JdCOoKPs5TKYe4poVNSbrTY5P1alVKjQjeHqizuFjUxw+Ns3h41McOZ71Wg489CR/+6NxVvbXmG40qVZEvdHk6Yn6bD3Vimg081NVen4nZHzhO2Nzn7lkM7VLsLq/xnQz+wdluhFUK2KgVqGewnGgVqGRQkUpjCbqDfqrFfqqFVLupRBsDUSdFIytQXnk+HS2W7hWpdG2Edu3aUTw82enOHN1PwO16myPvhlBM/W+K4JjUw0G+6vUKjrpdyzixO9fJQXqTC++3gyazZgN3ueGeva3Ualk7RWdeF/MI9PmG4PPN0DPWNnP7f/mDfN8tbkrM0g6veP2P7du83Rr7/QvUcc/YUmXAZcBvOxlL+tepT3nl7Mv/ec58311y3+ifSkzX7KqHyCF6Mn/XS/UuetWLer6no9mM5hqNGlG2vWX3vBkvclg2n0381wjsg+n01bU+OnTE/z8mSn6axUqguNTTY5N1WlEsHawj7GnjrNmoMZ5v7Caex85zMnZdOJDLvve+jg4OlEnItt9V6todvfeQK3C+NFJ1g72EQGNiJYP1Ugf+idPNyNSD2PmAzgLDJH1vmpVUatkX82AyXrW2+mrVGZ7s5WKZtcz88/EdNpdOvvBnT7kn/t+Wt5vwKqBGhP1Bo1GkP0PcvLvYvvn5pmrB3j08HGaES29es0u22wGKweqHJ/Kem8zgTCzrtZhj2a2FUBfVemfhlR7h7pnH88G2PzuVTTvzv48FlyzotTD4LPKfJUx4JyWxxuBR7vMMyapBqwFnixYtmidAETEjcCNAENDQ6fgjhorQ6UiVlSe28Oc2a3T3+UYy1lrBzlr7WDX9Z5/9trZ6bed/w8WWKVZbylzp+fdwBZJmyX1kx08H26bZxi4JE3vBPZFFvHDwC5JA5I2A1uAA3Ncp5mZLaHSeiTpmMcVwJ1kp+reHBEHJV0NjETEMHATcGs6mP4kWTCQ5rud7CB6Hbg8IhoAndZZ1nswM7NiejHcj3xoaChGRkaWuwwzsxcMSfdExNBc5l3+8/nMzOwFzUFiZmYL4iAxM7MFcZCYmdmCOEjMzGxBXhRnbUkaB348z8XPBJ5YxHIWk2ubv16uz7XNXy/X90Kr7dyIWD+XhV8UQbIQkkbmegrcUnNt89fL9bm2+evl+k7l2rxry8zMFsRBYmZmC+IgKXbjcheQw7XNXy/X59rmr5frO2Vr8zESMzNbEPdIzMxsQRwkXUjaLumQpFFJVy53PQCSHpb0A0n3ShpJbWdI+rqkH6XvL1miWm6W9Lik+1raOtaizPVpW35f0muXobaPSPpJ2nb3SnpHy3O7U22HJL295NrOkXSXpAckHZT0gdTeK9uuW33Lvv0krZB0QNL3Um3/KbVvlrQ/bbvb0i0mSLehuC3Vtl/SpmWo7TOSHmrZbq9O7Uv6c02vWZX0XUlfSY8Xb7tldwPzV+sX2RD1/xd4OdAPfA/Y2gN1PQyc2db234Ar0/SVwLVLVMtbgNcC9xXVArwD+Cuy29+9Hti/DLV9BPiDDvNuTT/fAWBz+rlXS6ztLOC1aXoN8MNUQ69su271Lfv2S9tgdZruA/anbXI7sCu1/ynwvjT9b4E/TdO7gNtK3G7davsMsLPD/Ev6c02v+fvAXwJfSY8Xbbu5R9LZNmA0Ih6MiClgD7BjmWvqZgdwS5q+BXjXUrxoRPwt2T1k5lLLDuCzkfk74HRJZy1xbd3sAPZExGREPASMkv38y6rtsYj4Tpo+CjwAbKB3tl23+rpZsu2XtsEz6WFf+grgrcDe1N6+7Wa26V7gAul53vR84bV1s6Q/V0kbgV8HPp0ei0Xcbg6SzjYAj7Q8HiP/j2mpBPA1Sfcouyc9wEsj4jHIPgSAX1i26rrX0ivb84q0G+Hmll2Ay1Zb2mXwGrL/Xntu27XVBz2w/dLumXuBx4Gvk/WADkdEvcPrz9aWnj8CrFuq2iJiZrtdk7bbdZIG2mvrUHcZPgZ8CGimx+tYxO3mIOmsU/r2wultb4qI1wIXApdLestyFzRHvbA9PwX8IvBq4DHgf6T2ZalN0mrgC8AHI+LpvFk7tC1HfT2x/SKiERGvBjaS9Xx+Jef1l7U2Sa8AdgO/DPwqcAbwR0tdm6R3Ao9HxD2tzTmv/7xrc5B0Ngac0/J4I/DoMtUyKyIeTd8fB75E9of0s5kucfr++PJV2LWWZd+eEfGz9IfeBP6ME7tflrw2SX1kH9Kfi4gvpuae2Xad6uul7ZfqOQx8k+z4wumSZm4b3vr6s7Wl59cy912ei1Hb9rSrMCJiEvhzlme7vQn4DUkPk+2mfytZD2XRtpuDpLO7gS3prIZ+sgNOw8tZkKRVktbMTANvA+5LdV2SZrsE+PLyVAg5tQwDv5POVHk9cGRmN85Sadv//Jtk226mtl3pTJXNwBbgQIl1CLgJeCAiPtryVE9su2719cL2k7Re0ulpehD4NbJjOHcBO9Ns7dtuZpvuBPZFOoK8RLX9fcs/ByI7BtG63Zbk5xoRuyNiY0RsIvss2xcR72Ext1vZZwq8UL/Izqr4Idk+2A/3QD0vJzs75nvAwZmayPZdfgP4Ufp+xhLV83myXRzTZP/BXNqtFrKu8g1pW/4AGFqG2m5Nr/399IdyVsv8H061HQIuLLm2N5PtJvg+cG/6ekcPbbtu9S379gNeCXw31XAfcFXL38YBsgP9/xMYSO0r0uPR9PzLl6G2fWm73Qf8BSfO7FrSn2tLnf+EE2dtLdp285XtZma2IN61ZWZmC+IgMTOzBXGQmJnZgjhIzMxsQRwkZma2IA4Ss0UgqdEywuu9WsQRoyVtUstIxma9plY8i5nNwfHIhscwe9Fxj8SsRMruIXNtulfFAUnnpfZzJX0jDeb3DUkvS+0vlfQlZfe1+J6kN6ZVVSX9mbJ7XXwtXT1t1hMcJGaLY7Bt19ZFLc89HRHbgE+QjXFEmv5sRLwS+BxwfWq/HvibiHgV2T1VDqb2LcANEXE+cBj4FyW/H7M585XtZotA0jMRsbpD+8PAWyPiwTQY4k8jYp2kJ8iGGZlO7Y9FxJmSxoGNkQ3yN7OOTWTDkm9Jj/8I6IuI/1z+OzMr5h6JWfmiy3S3eTqZbJlu4OOb1kMcJGblu6jl+7fT9P8hG4kV4D3At9L0N4D3weyNkk5bqiLN5sv/1ZgtjsF0d7wZX42ImVOAByTtJ/vH7eLU9n7gZkl/CIwDv5vaPwDcKOlSsp7H+8hGMjbrWT5GYlaidIxkKCKeWO5azMriXVtmZrYg7pGYmdmCuEdiZmYL4iAxM7MFcZCYmdmCOEjMzGxBH
CRmZrYgDhIzM1uQ/w/ptlV1wVOjOwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#Plot the loss\n",
    "\n",
    "plt.plot(rnn.loss_list)\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Cost\")\n",
    "plt.show()"
   ]
  },
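  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Optional variant: since the cost drops by several orders of magnitude,\n",
    "#a logarithmic y-axis makes the tail of the curve easier to read.\n",
    "\n",
    "plt.semilogy(rnn.loss_list)\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Cost (log scale)\")\n",
    "plt.show()"
   ]
  },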
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "#save in a folder that describes the model\n",
    "folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "rnn.save(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm_leaky_relu_5l_[50,40,30,20,10]c/rnn_basic\n"
     ]
    }
   ],
   "source": [
    "#folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "#rnn.load(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "###test_input.shape###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Here I predict based on my test set\n",
    "\n",
    "test_pred = rnn.predict(test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.00610282  0.00100984  0.02600916]\n",
      " [ 0.01632101 -0.01520294  0.02987524]\n",
      " [ 0.06068288  0.00697896  0.06441782]\n",
      " [-0.0119639  -0.04535145  0.07225598]\n",
      " [ 0.04132241  0.01145548  0.05150088]\n",
      " [-0.03290992  0.10355402  0.09310361]\n",
      " [ 0.00265487  0.04124176  0.08941123]]\n"
     ]
    }
   ],
   "source": [
    "#Here i subtract a prediction (random particle) from the target to get an idea of the predictions\n",
    "\n",
    "#scaler_inv(test_input, scalerfunc = func)[0,:,:]\n",
    "\n",
    "\n",
    "diff = scaler_inv(test_pred, scalerfunc = func)-scaler_inv(test_target, scalerfunc = func )\n",
    "\n",
    "print(diff[0,:,:])"
   ]
  },
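  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#The coordinate-wise differences above condense to one Euclidean distance per\n",
    "#predicted hit (sketch; distances are in the units of the unscaled data):\n",
    "\n",
    "dist = np.linalg.norm(diff, axis=2)\n",
    "print(dist[0,:])\n",
    "print(\"Mean distance per hit: \", np.mean(dist))"
   ]
  },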
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7.513113e-06"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Here I evaluate my model on the test set based on mean_squared_error\n",
    "\n",
    "rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}