Newer
Older
rnn_bachelor_thesis / 1_to_1_multi_layer.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "from sklearn import preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the matched tracks (8 hits with x, y, z each).\n",
    "# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;\n",
    "# only load this file from a trusted source.\n",
    "DATA_FILE = 'matched_8hittracks.pkl'\n",
    "testset = pd.read_pickle(DATA_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the loaded table into a float32 numpy array (one row per particle track).\n",
    "tset = np.array(testset).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    timesteps = int(arr_orig.shape[1]/3)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    arr = np.zeros((number_examples, timesteps, 3))\n",
    "    \n",
    "    for i in range(number_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,t,0:3] = arr_orig[i,3*t:3*t+3]\n",
    "        \n",
    "    return arr\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    timesteps = int(array_shaped.shape[1])\n",
    "    num_examples = int(array_shaped.shape[0])\n",
    "    arr = np.zeros((num_examples, timesteps*3))\n",
    "    \n",
    "    for i in range(num_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,3*t:3*t+3] = array_shaped[i,t,:]\n",
    "        \n",
    "    return arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "    \n",
    "    #evaluate siye of training and test set and initialize them\n",
    "    train_set_size = np.int(num_examples*train_to_total_ratio)\n",
    "    test_set_size = num_examples - train_set_size\n",
    "    \n",
    "    train_set = np.zeros((train_set_size, dataset.shape[1]))\n",
    "    test_set = np.zeros((test_set_size, dataset.shape[1]))\n",
    "   \n",
    "\n",
    "    #fill train and test sets\n",
    "    for i in range(num_examples):\n",
    "        if train_set_size > i:\n",
    "            train_set[i,:] += dataset[i,:]\n",
    "        else:\n",
    "            test_set[i - train_set_size,:]  += dataset[i,:]\n",
    "                \n",
    "    return train_set, test_set\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep 99% of the examples for training and hold out the remaining 1% for testing.\n",
    "train_set, test_set = create_random_sets(tset, train_to_total_ratio=0.99)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize the data with scikit-learn's MinMaxScaler, fitted on the training set only.\n",
    "\n",
    "def _apply_scaler(method, arr):\n",
    "    \"\"\"Apply a fitted scaler method (e.g. ``transform``) to a 2-D or 3-D array.\n",
    "\n",
    "    3-D input is flattened with ``reshapor_inv``, passed to ``method`` and\n",
    "    reshaped back with ``reshapor``, so the scaler always sees the 2-D column\n",
    "    layout it was fitted on. Any other input goes to ``method`` unchanged.\n",
    "    \"\"\"\n",
    "    if len(arr.shape) == 3:\n",
    "        return reshapor(method(reshapor_inv(arr)))\n",
    "    return method(arr)\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_min_max_scalor(arr, feature_range= (-1,1)):\n",
    "    \"\"\"Fit and return a MinMaxScaler on ``arr`` (2-D, or 3-D flattened via ``reshapor_inv``).\n",
    "\n",
    "    Each column of the 2-D view is mapped independently onto ``feature_range``.\n",
    "    \"\"\"\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    # Only the fitted scaler is returned, so fit() is enough; the old code's\n",
    "    # fit_transform + reshape result was computed and then discarded.\n",
    "    if len(arr.shape) == 3:\n",
    "        min_max_scalor.fit(reshapor_inv(arr))\n",
    "    else:\n",
    "        min_max_scalor.fit(arr)\n",
    "    return min_max_scalor\n",
    "\n",
    "min_max_scalor = set_min_max_scalor(train_set)\n",
    "\n",
    "# NOTE: the default scaler of the functions below is bound when they are defined;\n",
    "# after refitting, re-run this cell or pass min_max_scalor explicitly.\n",
    "# Inputs must have as many columns (2-D) or timesteps * features (3-D) as the\n",
    "# data the scaler was fitted on; sklearn rejects any other width.\n",
    "\n",
    "#transform data\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "    \"\"\"Scale ``arr`` (2-D or 3-D) with the fitted ``min_max_scalor``.\"\"\"\n",
    "    return _apply_scaler(min_max_scalor.transform, arr)\n",
    "\n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "    \"\"\"Undo ``min_max_scaler``: map scaled values back to the original scale.\"\"\"\n",
    "    return _apply_scaler(min_max_scalor.inverse_transform, arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring both sets into the (examples, timesteps, 3) layout the RNN consumes.\n",
    "# NOTE(review): this rebinds train_set/test_set, so re-running this cell on its\n",
    "# own would reshape the already-3-D arrays again; re-run from the split cell.\n",
    "train_set, test_set = reshapor(train_set), reshapor(test_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale both sets with the scaler that was fitted on the training data.\n",
    "train_set, test_set = min_max_scaler(train_set), min_max_scaler(test_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#train_set = min_max_scaler_inv(train_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    \n",
    "    num_examples = inputt.shape[0]\n",
    "    \n",
    "    \n",
    "    #Number of complete batches\n",
    "    \n",
    "    number_of_batches = int(num_examples/minibatch_size)\n",
    "    minibatches = []\n",
    "   \n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    #print(_t.shape)\n",
    "        \n",
    "    \n",
    "    for i in range(number_of_batches):\n",
    "        \n",
    "        minibatch_train = _i[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatch_true = _t[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatches.append((minibatch_train, minibatch_true))\n",
    "        \n",
    "        \n",
    "    minibatches.append((_i[number_of_batches*minibatch_size:, :, :], _t[number_of_batches*minibatch_size:, :, :]))\n",
    "    \n",
    "    \n",
    "    return minibatches\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next-hit prediction: the input is every hit of a track except the last one and\n",
    "# the target is every hit except the first one (the input shifted by one step).\n",
    "minibatches = random_mini_batches(\n",
    "    train_set[:, :-1, :],\n",
    "    train_set[:, 1:, :],\n",
    "    minibatch_size=1000,\n",
    ")\n",
    "test_input = test_set[:, :-1, :]\n",
    "test_target = test_set[:, 1:, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \"\"\"Stacked RNN that predicts, for every timestep, the next point of a time series.\n",
    "\n",
    "    Graph: a MultiRNNCell with one layer per entry of ``ncells``, wrapped in an\n",
    "    OutputProjectionWrapper (``num_output`` features per timestep) and unrolled\n",
    "    with ``tf.nn.dynamic_rnn``. Training minimizes the mean squared error\n",
    "    between the output and ``Y`` with Adam.\n",
    "\n",
    "    Typical use: construct, call ``set_cost_and_functions()`` once, then\n",
    "    ``fit()`` and ``predict()``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\"):\n",
    "        \"\"\"Create the placeholders, the stacked cell and a TensorFlow session.\n",
    "\n",
    "        Args:\n",
    "            time_steps: timesteps per example in ``X`` and ``Y``.\n",
    "            future_steps: stored as ``self.future_steps``; not used by the graph.\n",
    "            ninputs: features per timestep of the input ``X``.\n",
    "            ncells: int or list/tuple of ints, units of each stacked layer.\n",
    "            num_output: features per timestep of the output and of ``Y``.\n",
    "            cell_type: \"basic_rnn\", \"lstm\" or \"GRU\".\n",
    "            activation: \"relu\", \"tanh\", \"leaky_relu\" or \"elu\"; used by every\n",
    "                layer except the last one, which always uses tanh.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: for an unknown activation or cell type, or invalid ncells.\n",
    "        \"\"\"\n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "\n",
    "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
    "        #### ninputs is the dimensionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "        # The target has to match the projected output, i.e. num_output features\n",
    "        # per timestep (it used to be declared with ninputs, which only worked\n",
    "        # when num_output == ninputs).\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, num_output))\n",
    "\n",
    "        #Check if activation function valid and set activation\n",
    "        if activation==\"relu\":\n",
    "            self.activation = tf.nn.relu\n",
    "        elif activation==\"tanh\":\n",
    "            self.activation = tf.nn.tanh\n",
    "        elif activation==\"leaky_relu\":\n",
    "            self.activation = tf.nn.leaky_relu\n",
    "        elif activation==\"elu\":\n",
    "            self.activation = tf.nn.elu\n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn activation function: {}\".format(activation))\n",
    "\n",
    "        #Check if cell type valid and set cell_type\n",
    "        if cell_type==\"basic_rnn\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicRNNCell\n",
    "        elif cell_type==\"lstm\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicLSTMCell\n",
    "        elif cell_type==\"GRU\":\n",
    "            self.cell_type = tf.contrib.rnn.GRUCell\n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(cell_type))\n",
    "\n",
    "        #Normalize ncells to a non-empty list of ints\n",
    "        if type(self.ncells) == int:\n",
    "            self.ncells = [self.ncells]\n",
    "        elif type(self.ncells) == tuple:\n",
    "            self.ncells = list(self.ncells)\n",
    "        if type(self.ncells) != list or len(self.ncells) == 0:\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        if any(type(units) != int for units in self.ncells):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "\n",
    "        #Hidden layers use the chosen activation, the last layer always uses tanh\n",
    "        self.activationlist = [self.activation] * (len(self.ncells) - 1) + [tf.nn.tanh]\n",
    "\n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=units, activation=act)\n",
    "                                                 for units, act in zip(self.ncells, self.activationlist)])\n",
    "\n",
    "        #### Project the top layer's output to num_output features per timestep\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "\n",
    "        self.sess = tf.Session()\n",
    "\n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        \"\"\"Unroll the cell, define the MSE cost and Adam step, and initialize variables.\n",
    "\n",
    "        Must be called once before ``fit``, ``save``, ``load`` or ``predict``.\n",
    "\n",
    "        Args:\n",
    "            LR: learning rate of the Adam optimizer.\n",
    "        \"\"\"\n",
    "        #### Unroll the RNN cell over the time dimension of X\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### Cost: mean squared error between predicted and target points\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))\n",
    "\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "\n",
    "    def save(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Save the session's variables under the checkpoint prefix ``filename``.\"\"\"\n",
    "        self.saver.save(self.sess, filename)\n",
    "\n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Restore the session's variables from the checkpoint prefix ``filename``.\"\"\"\n",
    "        self.saver.restore(self.sess, filename)\n",
    "\n",
    "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
    "        \"\"\"Train on ``minibatches`` for at most ``epochs`` epochs.\n",
    "\n",
    "        The epoch loss (mean over batches of the cost measured right after each\n",
    "        batch's training step) is appended to ``self.loss_list``. Every\n",
    "        ``checkpoint`` epochs the model is saved if its loss beats the last saved\n",
    "        one; epoch 0 is always saved, so a checkpoint from this run exists.\n",
    "        Training stops once ``patience`` epochs (counted in total, not\n",
    "        consecutively) changed the loss by less than 2e-6. At the end the saved\n",
    "        checkpoint is restored if its loss is lower than the final epoch's.\n",
    "\n",
    "        Args:\n",
    "            minibatches: list of (input, target) array pairs.\n",
    "            epochs: maximum number of epochs.\n",
    "            print_step: print progress every ``print_step`` epochs.\n",
    "            checkpoint: checkpoint interval in epochs.\n",
    "            patience: number of near-flat epochs tolerated before stopping.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``minibatches`` is empty.\n",
    "        \"\"\"\n",
    "        self.loss_list = []\n",
    "        patience_cnt = 0\n",
    "        epoche_save = 0\n",
    "\n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n",
    "\n",
    "        batches = len(minibatches)\n",
    "        if batches == 0:\n",
    "            raise ValueError(\"fit() needs at least one minibatch\")\n",
    "\n",
    "        for iep in range(epochs):\n",
    "            loss = 0\n",
    "            #### X is the time series, Y is the series shifted by 1 time step\n",
    "            for train, target in minibatches:\n",
    "                feed = {self.X: train, self.Y: target}\n",
    "                self.sess.run(self.train, feed_dict=feed)\n",
    "                loss += self.sess.run(self.cost, feed_dict=feed)\n",
    "\n",
    "            #Mean batch loss of this epoch\n",
    "            loss /= batches\n",
    "            self.loss_list.append(loss)\n",
    "\n",
    "            # Checkpoint when performance improved. Epoch 0 is always saved so the\n",
    "            # restore at the end never loads a checkpoint this run did not write\n",
    "            # (previously it could hit a missing or stale checkpoint).\n",
    "            if iep % checkpoint == 0 and (iep == 0 or self.loss_list[iep] < self.loss_list[epoche_save]):\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "\n",
    "            # Early stopping: count epochs whose loss barely changed (never reset)\n",
    "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 2/1000000:\n",
    "                patience_cnt += 1\n",
    "                if patience_cnt + 1 > patience:\n",
    "                    print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
    "                    print(\"Cost: \",loss)\n",
    "                    break\n",
    "\n",
    "            #Note that the loss here is multiplied with 1000 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*1000, \"e-3\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "\n",
    "        # No epoch ran (epochs <= 0): there is nothing to compare or restore.\n",
    "        if not self.loss_list:\n",
    "            return\n",
    "        last = len(self.loss_list) - 1\n",
    "\n",
    "        #Set model back to the last checkpoint if performance was better\n",
    "        if self.loss_list[epoche_save] < self.loss_list[last]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\", \"Last checkpoint at epoch \", epoche_save, \" loaded\")\n",
    "            print(\"Performance at last checkpoint is \" ,self.loss_list[last] - self.loss_list[epoche_save], \" better\" )\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the network output for input ``x`` (fed to the ``X`` placeholder).\"\"\"\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [50, 40, 30, 20, 10]\n",
    "\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n"
     ]
    }
   ],
   "source": [
    "# Clear the default graph so the model is built from scratch each time this cell runs.\n",
    "tf.reset_default_graph()\n",
    "rnn = RNNPlacePrediction(\n",
    "    time_steps=timesteps,\n",
    "    future_steps=future_steps,\n",
    "    ninputs=ninputs,\n",
    "    ncells=ncells,\n",
    "    num_output=num_output,\n",
    "    cell_type='lstm',\n",
    "    activation='leaky_relu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the cost, the Adam training step and the saver, then initialize variables.\n",
    "rnn.set_cost_and_functions(LR=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch number  0\n",
      "Cost:  3770.231458734959 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  0 \n",
      "\n",
      "Epoch number  5\n",
      "Cost:  1649.7736788810569 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  5 \n",
      "\n",
      "Epoch number  10\n",
      "Cost:  625.2868418046768 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  10 \n",
      "\n",
      "Epoch number  15\n",
      "Cost:  294.9610768639027 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  15 \n",
      "\n",
      "Epoch number  20\n",
      "Cost:  209.0108957379422 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  20 \n",
      "\n",
      "Epoch number  25\n",
      "Cost:  174.1866168982171 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  25 \n",
      "\n",
      "Epoch number  30\n",
      "Cost:  149.8719225538538 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  30 \n",
      "\n",
      "Epoch number  35\n",
      "Cost:  131.33942407179387 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  35 \n",
      "\n",
      "Epoch number  40\n",
      "Cost:  115.83642023516462 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  40 \n",
      "\n",
      "Epoch number  45\n",
      "Cost:  107.55172256935151 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  45 \n",
      "\n",
      "Epoch number  50\n",
      "Cost:  98.54952309359895 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  50 \n",
      "\n",
      "Epoch number  55\n",
      "Cost:  95.66065657170529 e4\n",
      "Patience:  0 / 200\n",
      "Last checkpoint at: Epoch  55 \n",
      "\n",
      "Epoch number  60\n",
      "Cost:  90.34742145462239 e4\n",
      "Patience:  1 / 200\n",
      "Last checkpoint at: Epoch  60 \n",
      "\n",
      "Epoch number  65\n",
      "Cost:  84.77292855844853 e4\n",
      "Patience:  2 / 200\n",
      "Last checkpoint at: Epoch  65 \n",
      "\n",
      "Epoch number  70\n",
      "Cost:  78.54001398416275 e4\n",
      "Patience:  3 / 200\n",
      "Last checkpoint at: Epoch  70 \n",
      "\n",
      "Epoch number  75\n",
      "Cost:  75.23123551397882 e4\n",
      "Patience:  3 / 200\n",
      "Last checkpoint at: Epoch  75 \n",
      "\n",
      "Epoch number  80\n",
      "Cost:  73.33986362085697 e4\n",
      "Patience:  4 / 200\n",
      "Last checkpoint at: Epoch  80 \n",
      "\n",
      "Epoch number  85\n",
      "Cost:  69.12997319422504 e4\n",
      "Patience:  5 / 200\n",
      "Last checkpoint at: Epoch  85 \n",
      "\n",
      "Epoch number  90\n",
      "Cost:  65.79162087291479 e4\n",
      "Patience:  5 / 200\n",
      "Last checkpoint at: Epoch  90 \n",
      "\n",
      "Epoch number  95\n",
      "Cost:  61.82488113483216 e4\n",
      "Patience:  6 / 200\n",
      "Last checkpoint at: Epoch  95 \n",
      "\n",
      "Epoch number  100\n",
      "Cost:  59.33671109774646 e4\n",
      "Patience:  8 / 200\n",
      "Last checkpoint at: Epoch  100 \n",
      "\n",
      "Epoch number  105\n",
      "Cost:  57.19678456637453 e4\n",
      "Patience:  9 / 200\n",
      "Last checkpoint at: Epoch  105 \n",
      "\n",
      "Epoch number  110\n",
      "Cost:  55.66507161773266 e4\n",
      "Patience:  10 / 200\n",
      "Last checkpoint at: Epoch  110 \n",
      "\n",
      "Epoch number  115\n",
      "Cost:  54.365597526602286 e4\n",
      "Patience:  13 / 200\n",
      "Last checkpoint at: Epoch  115 \n",
      "\n",
      "Epoch number  120\n",
      "Cost:  52.487826807067755 e4\n",
      "Patience:  14 / 200\n",
      "Last checkpoint at: Epoch  120 \n",
      "\n",
      "Epoch number  125\n",
      "Cost:  51.60155072015651 e4\n",
      "Patience:  17 / 200\n",
      "Last checkpoint at: Epoch  125 \n",
      "\n",
      "Epoch number  130\n",
      "Cost:  51.004822227232 e4\n",
      "Patience:  20 / 200\n",
      "Last checkpoint at: Epoch  130 \n",
      "\n",
      "Epoch number  135\n",
      "Cost:  49.656663347590474 e4\n",
      "Patience:  22 / 200\n",
      "Last checkpoint at: Epoch  135 \n",
      "\n",
      "Epoch number  140\n",
      "Cost:  49.04315717756114 e4\n",
      "Patience:  26 / 200\n",
      "Last checkpoint at: Epoch  140 \n",
      "\n",
      "Epoch number  145\n",
      "Cost:  48.333713487583275 e4\n",
      "Patience:  29 / 200\n",
      "Last checkpoint at: Epoch  145 \n",
      "\n",
      "Epoch number  150\n",
      "Cost:  47.4689517447606 e4\n",
      "Patience:  33 / 200\n",
      "Last checkpoint at: Epoch  150 \n",
      "\n",
      "Epoch number  155\n",
      "Cost:  46.82262457827938 e4\n",
      "Patience:  38 / 200\n",
      "Last checkpoint at: Epoch  155 \n",
      "\n",
      "Epoch number  160\n",
      "Cost:  46.189470573308625 e4\n",
      "Patience:  43 / 200\n",
      "Last checkpoint at: Epoch  160 \n",
      "\n",
      "Epoch number  165\n",
      "Cost:  45.566867759570165 e4\n",
      "Patience:  48 / 200\n",
      "Last checkpoint at: Epoch  165 \n",
      "\n",
      "Epoch number  170\n",
      "Cost:  45.00874754120695 e4\n",
      "Patience:  53 / 200\n",
      "Last checkpoint at: Epoch  170 \n",
      "\n",
      "Epoch number  175\n",
      "Cost:  44.46649339367101 e4\n",
      "Patience:  58 / 200\n",
      "Last checkpoint at: Epoch  175 \n",
      "\n",
      "Epoch number  180\n",
      "Cost:  43.92929008587244 e4\n",
      "Patience:  63 / 200\n",
      "Last checkpoint at: Epoch  180 \n",
      "\n",
      "Epoch number  185\n",
      "Cost:  43.44754183585656 e4\n",
      "Patience:  68 / 200\n",
      "Last checkpoint at: Epoch  185 \n",
      "\n",
      "Epoch number  190\n",
      "Cost:  42.95319576371223 e4\n",
      "Patience:  73 / 200\n",
      "Last checkpoint at: Epoch  190 \n",
      "\n",
      "Epoch number  195\n",
      "Cost:  42.52819289909082 e4\n",
      "Patience:  78 / 200\n",
      "Last checkpoint at: Epoch  195 \n",
      "\n",
      "Epoch number  200\n",
      "Cost:  41.93341770665126 e4\n",
      "Patience:  83 / 200\n",
      "Last checkpoint at: Epoch  200 \n",
      "\n",
      "Epoch number  205\n",
      "Cost:  41.554861285902085 e4\n",
      "Patience:  88 / 200\n",
      "Last checkpoint at: Epoch  205 \n",
      "\n",
      "Epoch number  210\n",
      "Cost:  41.090038733834284 e4\n",
      "Patience:  93 / 200\n",
      "Last checkpoint at: Epoch  210 \n",
      "\n",
      "Epoch number  215\n",
      "Cost:  40.845294889221165 e4\n",
      "Patience:  98 / 200\n",
      "Last checkpoint at: Epoch  215 \n",
      "\n",
      "Epoch number  220\n",
      "Cost:  40.25109122170412 e4\n",
      "Patience:  103 / 200\n",
      "Last checkpoint at: Epoch  220 \n",
      "\n",
      "Epoch number  225\n",
      "Cost:  39.58158948002977 e4\n",
      "Patience:  108 / 200\n",
      "Last checkpoint at: Epoch  225 \n",
      "\n",
      "Epoch number  230\n",
      "Cost:  38.97598008327979 e4\n",
      "Patience:  113 / 200\n",
      "Last checkpoint at: Epoch  230 \n",
      "\n",
      "Epoch number  235\n",
      "Cost:  38.51150915502234 e4\n",
      "Patience:  118 / 200\n",
      "Last checkpoint at: Epoch  235 \n",
      "\n",
      "Epoch number  240\n",
      "Cost:  38.299499218292695 e4\n",
      "Patience:  123 / 200\n",
      "Last checkpoint at: Epoch  240 \n",
      "\n",
      "Epoch number  245\n",
      "Cost:  37.74655878821269 e4\n",
      "Patience:  128 / 200\n",
      "Last checkpoint at: Epoch  245 \n",
      "\n",
      "Epoch number  250\n",
      "Cost:  37.40582783567778 e4\n",
      "Patience:  133 / 200\n",
      "Last checkpoint at: Epoch  250 \n",
      "\n",
      "Epoch number  255\n",
      "Cost:  37.24810196720856 e4\n",
      "Patience:  138 / 200\n",
      "Last checkpoint at: Epoch  255 \n",
      "\n",
      "Epoch number  260\n",
      "Cost:  37.280498320197175 e4\n",
      "Patience:  143 / 200\n",
      "Last checkpoint at: Epoch  255 \n",
      "\n",
      "Epoch number  265\n",
      "Cost:  36.25094043487247 e4\n",
      "Patience:  147 / 200\n",
      "Last checkpoint at: Epoch  265 \n",
      "\n",
      "Epoch number  270\n",
      "Cost:  36.03106825315255 e4\n",
      "Patience:  152 / 200\n",
      "Last checkpoint at: Epoch  270 \n",
      "\n",
      "Epoch number  275\n",
      "Cost:  35.67509779191398 e4\n",
      "Patience:  156 / 200\n",
      "Last checkpoint at: Epoch  275 \n",
      "\n",
      "Epoch number  280\n",
      "Cost:  35.42137842506487 e4\n",
      "Patience:  161 / 200\n",
      "Last checkpoint at: Epoch  280 \n",
      "\n",
      "Epoch number  285\n",
      "Cost:  35.79035718390282 e4\n",
      "Patience:  164 / 200\n",
      "Last checkpoint at: Epoch  280 \n",
      "\n",
      "Epoch number  290\n",
      "Cost:  33.758991754594 e4\n",
      "Patience:  165 / 200\n",
      "Last checkpoint at: Epoch  290 \n",
      "\n",
      "Epoch number  295\n",
      "Cost:  34.39420328891658 e4\n",
      "Patience:  166 / 200\n",
      "Last checkpoint at: Epoch  290 \n",
      "\n",
      "Epoch number  300\n",
      "Cost:  33.66679522862777 e4\n",
      "Patience:  166 / 200\n",
      "Last checkpoint at: Epoch  300 \n",
      "\n",
      "Epoch number  305\n",
      "Cost:  34.23552023880976 e4\n",
      "Patience:  167 / 200\n",
      "Last checkpoint at: Epoch  300 \n",
      "\n",
      "Epoch number  310\n",
      "Cost:  33.27848409560132 e4\n",
      "Patience:  168 / 200\n",
      "Last checkpoint at: Epoch  310 \n",
      "\n",
      "Epoch number  315\n",
      "Cost:  32.72916789741275 e4\n",
      "Patience:  171 / 200\n",
      "Last checkpoint at: Epoch  315 \n",
      "\n",
      "Epoch number  320\n",
      "Cost:  32.42362023113255 e4\n",
      "Patience:  173 / 200\n",
      "Last checkpoint at: Epoch  320 \n",
      "\n",
      "Epoch number  325\n",
      "Cost:  33.13556412591579 e4\n",
      "Patience:  173 / 200\n",
      "Last checkpoint at: Epoch  320 \n",
      "\n",
      "Epoch number  330\n",
      "Cost:  34.35548811041294 e4\n",
      "Patience:  173 / 200\n",
      "Last checkpoint at: Epoch  320 \n",
      "\n",
      "Epoch number  335\n",
      "Cost:  31.17884152588692 e4\n",
      "Patience:  174 / 200\n",
      "Last checkpoint at: Epoch  335 \n",
      "\n",
      "Epoch number  340\n",
      "Cost:  33.64366251341206 e4\n",
      "Patience:  174 / 200\n",
      "Last checkpoint at: Epoch  335 \n",
      "\n",
      "Epoch number  345\n",
      "Cost:  32.388941939682404 e4\n",
      "Patience:  175 / 200\n",
      "Last checkpoint at: Epoch  335 \n",
      "\n",
      "Epoch number  350\n",
      "Cost:  29.8897856648298 e4\n",
      "Patience:  175 / 200\n",
      "Last checkpoint at: Epoch  350 \n",
      "\n",
      "Epoch number  355\n",
      "Cost:  30.779531522792706 e4\n",
      "Patience:  176 / 200\n",
      "Last checkpoint at: Epoch  350 \n",
      "\n",
      "Epoch number  360\n",
      "Cost:  32.77950439641767 e4\n",
      "Patience:  177 / 200\n",
      "Last checkpoint at: Epoch  350 \n",
      "\n",
      "Epoch number  365\n",
      "Cost:  34.279519781232516 e4\n",
      "Patience:  177 / 200\n",
      "Last checkpoint at: Epoch  350 \n",
      "\n",
      "Epoch number  370\n",
      "Cost:  29.02430596147129 e4\n",
      "Patience:  177 / 200\n",
      "Last checkpoint at: Epoch  370 \n",
      "\n",
      "Epoch number  375\n",
      "Cost:  31.375054398828997 e4\n",
      "Patience:  178 / 200\n",
      "Last checkpoint at: Epoch  370 \n",
      "\n",
      "Epoch number  380\n",
      "Cost:  33.813590144223355 e4\n",
      "Patience:  178 / 200\n",
      "Last checkpoint at: Epoch  370 \n",
      "\n",
      "Epoch number  385\n",
      "Cost:  28.6719871268786 e4\n",
      "Patience:  178 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  390\n",
      "Cost:  31.848519872081408 e4\n",
      "Patience:  179 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  395\n",
      "Cost:  29.007866582337847 e4\n",
      "Patience:  181 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  400\n",
      "Cost:  33.16965553552863 e4\n",
      "Patience:  181 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  405\n",
      "Cost:  32.650657305295795 e4\n",
      "Patience:  181 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  410\n",
      "Cost:  28.816359365319318 e4\n",
      "Patience:  181 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch number  415\n",
      "Cost:  29.141941761716886 e4\n",
      "Patience:  181 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  420\n",
      "Cost:  30.577135856877614 e4\n",
      "Patience:  182 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  425\n",
      "Cost:  29.400000695456217 e4\n",
      "Patience:  183 / 200\n",
      "Last checkpoint at: Epoch  385 \n",
      "\n",
      "Epoch number  430\n",
      "Cost:  26.99479599423865 e4\n",
      "Patience:  183 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  435\n",
      "Cost:  30.304402994744958 e4\n",
      "Patience:  184 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  440\n",
      "Cost:  29.647010675770172 e4\n",
      "Patience:  184 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  445\n",
      "Cost:  27.00613232012442 e4\n",
      "Patience:  185 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  450\n",
      "Cost:  27.036350567210864 e4\n",
      "Patience:  186 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  455\n",
      "Cost:  27.08697458729148 e4\n",
      "Patience:  187 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  460\n",
      "Cost:  28.004820329791055 e4\n",
      "Patience:  188 / 200\n",
      "Last checkpoint at: Epoch  430 \n",
      "\n",
      "Epoch number  465\n",
      "Cost:  26.3666685551722 e4\n",
      "Patience:  188 / 200\n",
      "Last checkpoint at: Epoch  465 \n",
      "\n",
      "Epoch number  470\n",
      "Cost:  26.36444576560183 e4\n",
      "Patience:  188 / 200\n",
      "Last checkpoint at: Epoch  470 \n",
      "\n",
      "Epoch number  475\n",
      "Cost:  31.123574119695324 e4\n",
      "Patience:  188 / 200\n",
      "Last checkpoint at: Epoch  470 \n",
      "\n",
      "Epoch number  480\n",
      "Cost:  27.53822227068087 e4\n",
      "Patience:  189 / 200\n",
      "Last checkpoint at: Epoch  470 \n",
      "\n",
      "Epoch number  485\n",
      "Cost:  26.472763485334657 e4\n",
      "Patience:  189 / 200\n",
      "Last checkpoint at: Epoch  470 \n",
      "\n",
      "Epoch number  490\n",
      "Cost:  25.98736776990142 e4\n",
      "Patience:  190 / 200\n",
      "Last checkpoint at: Epoch  490 \n",
      "\n",
      "Epoch number  495\n",
      "Cost:  25.32091308781441 e4\n",
      "Patience:  191 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  500\n",
      "Cost:  26.51548171614079 e4\n",
      "Patience:  191 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  505\n",
      "Cost:  25.78474184934129 e4\n",
      "Patience:  191 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  510\n",
      "Cost:  26.016250708477294 e4\n",
      "Patience:  191 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  515\n",
      "Cost:  28.13248825754891 e4\n",
      "Patience:  191 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  520\n",
      "Cost:  28.441735156910852 e4\n",
      "Patience:  191 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  525\n",
      "Cost:  25.8854781079324 e4\n",
      "Patience:  193 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  530\n",
      "Cost:  25.448204473929202 e4\n",
      "Patience:  193 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  535\n",
      "Cost:  26.26546668483222 e4\n",
      "Patience:  193 / 200\n",
      "Last checkpoint at: Epoch  495 \n",
      "\n",
      "Epoch number  540\n",
      "Cost:  24.608338271525312 e4\n",
      "Patience:  196 / 200\n",
      "Last checkpoint at: Epoch  540 \n",
      "\n",
      "Epoch number  545\n",
      "Cost:  25.521852422822665 e4\n",
      "Patience:  196 / 200\n",
      "Last checkpoint at: Epoch  540 \n",
      "\n",
      "Epoch number  550\n",
      "Cost:  24.915404786217085 e4\n",
      "Patience:  198 / 200\n",
      "Last checkpoint at: Epoch  540 \n",
      "\n",
      "Epoch number  555\n",
      "Cost:  25.868487217404105 e4\n",
      "Patience:  198 / 200\n",
      "Last checkpoint at: Epoch  540 \n",
      "\n",
      "Epoch number  560\n",
      "Cost:  27.24954412576366 e4\n",
      "Patience:  199 / 200\n",
      "Last checkpoint at: Epoch  540 \n",
      "\n",
      "\n",
      " Early stopping at epoch  565 , difference:  2.3366942843223992e-05\n",
      "Cost:  0.002444114783739156\n"
     ]
    }
   ],
   "source": [
    "rnn.fit(minibatches, epochs = 5000, print_step=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#Plot the cost history stored in rnn.loss_list\n",
    "ax = plt.gca()\n",
    "ax.plot(rnn.loss_list)\n",
    "ax.set_xlabel(\"Epoch\")\n",
    "ax.set_ylabel(\"Cost\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Save the model in a folder whose name encodes its configuration,\n",
    "#e.g. ./rnn_model_lstm_leaky_relu_5l_[50,40,30,20,10]c/rnn_basic\n",
    "ncells_tag = str(rnn.ncells).replace(\" \", \"\")\n",
    "folder = \"./rnn_model_{}_{}_{}l_{}c/rnn_basic\".format(rnn._, rnn.__, len(rnn.ncells), ncells_tag)\n",
    "rnn.save(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Uncomment these lines to restore a previously saved model instead of retraining\n",
    "#(the folder name is built the same way as in the save cell).\n",
    "#folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "#rnn.load(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sanity check of the test set dimensions before predicting\n",
    "test_input.shape, test_target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Run the trained network on the held-out test inputs\n",
    "\n",
    "test_pred = rnn.predict(test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Difference between prediction and target for one arbitrary particle, to get an idea of the prediction quality.\n",
    "#NOTE: this is done in the scaled space. min_max_scaler was fitted on full 8-hit tracks (24 features),\n",
    "#but test_input only has 21 features (7 hits), so min_max_scaler(test_input) raised a broadcasting ValueError.\n",
    "#Presumably min_max_scaler_inv has the same restriction - TODO confirm before un-scaling test_pred/test_target.\n",
    "particle = 0  # index of the particle to inspect\n",
    "test_pred[particle] - test_target[particle]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Evaluate the trained model on the test set with its cost function\n",
    "#(rnn.cost - presumably the mean squared error used during training; verify in the model definition)\n",
    "test_feed = {rnn.X: test_input, rnn.Y: test_target}\n",
    "rnn.sess.run(rnn.cost, feed_dict=test_feed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}