{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "from sklearn import preprocessing\n",
    "import pickle as pkl\n",
    "from pathlib import Path\n",
    "\n",
    "#import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import data as array\n",
    "# 8 hits with x,y,z\n",
    "# (pd.read_pickle can execute arbitrary code: only load trusted files)\n",
    "\n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')\n",
    "#print(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Check testset with arbitrary particle\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "#print(tset.shape)\n",
    "#for i in range(8):\n",
    "    #print(tset[1,3*i:(3*i+3)])\n",
    "#print(tset[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    # Turn flat rows (x0,y0,z0,x1,y1,z1,...) of shape (examples, 3*timesteps)\n",
    "    # into an array of shape (examples, timesteps, 3).\n",
    "    # Trailing columns that do not form a full x,y,z triple are dropped.\n",
    "    # Always returns a new float64 array.\n",
    "    timesteps = int(arr_orig.shape[1]/3)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    arr = np.array(arr_orig[:, :3*timesteps], dtype=np.float64)\n",
    "    return arr.reshape((number_examples, timesteps, 3))\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    # Inverse of reshapor: (examples, timesteps, 3) -> (examples, 3*timesteps),\n",
    "    # flattened in the same x,y,z-per-timestep order.\n",
    "    # Always returns a new float64 array.\n",
    "    timesteps = int(array_shaped.shape[1])\n",
    "    num_examples = int(array_shaped.shape[0])\n",
    "    arr = np.array(array_shaped, dtype=np.float64)\n",
    "    return arr.reshape((num_examples, timesteps*3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    # Shuffle the rows of the 2D array `dataset` and split them into a training set\n",
    "    # (the first train_to_total_ratio fraction, rounded down) and a test set (the rest).\n",
    "    # Both are returned as new float64 arrays.\n",
    "    if not 0 <= train_to_total_ratio <= 1:\n",
    "        raise ValueError(\"train_to_total_ratio must be in [0, 1], got {}\".format(train_to_total_ratio))\n",
    "    \n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "    \n",
    "    #size of the training set (builtin int: np.int was removed in NumPy 1.24)\n",
    "    train_set_size = int(num_examples*train_to_total_ratio)\n",
    "    \n",
    "    #split by slicing instead of copying row by row\n",
    "    train_set = np.array(dataset[:train_set_size,:], dtype=np.float64)\n",
    "    test_set = np.array(dataset[train_set_size:,:], dtype=np.float64)\n",
    "    \n",
    "    return train_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "\n",
    "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "#print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_min_max_scaler(arr, feature_range= (-1,1)):\n",
    "    # Fit a MinMaxScaler on arr (2D, or 3D which is flattened with reshapor_inv)\n",
    "    # and return the fitted scaler; arr itself is not transformed.\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    if len(arr.shape) == 3:\n",
    "        min_max_scalor.fit(reshapor_inv(arr))\n",
    "    else:\n",
    "        min_max_scalor.fit(arr)\n",
    "    return min_max_scalor\n",
    "\n",
    "min_max_scalor = set_min_max_scaler(train_set)\n",
    "\n",
    "\n",
    "#apply a fitted scaler method to 2D or 3D data, zero-padding missing features\n",
    "def _apply_padded(func, arr, n_features=24):\n",
    "    # func is a bound method of a fitted scaler (transform or inverse_transform)\n",
    "    # that expects n_features columns (8 hits x 3 coordinates).\n",
    "    # 3D input (examples, timesteps, 3) is flattened with reshapor_inv and\n",
    "    # restored with reshapor. Inputs with fewer columns (fewer hits) are\n",
    "    # zero-padded on the right, passed through func and cut back, so every\n",
    "    # column is scaled with the statistics of its own hit/coordinate.\n",
    "    is_3d = len(arr.shape) == 3\n",
    "    flat = reshapor_inv(arr) if is_3d else arr\n",
    "    n_cols = flat.shape[1]\n",
    "    if n_cols == n_features:\n",
    "        out = func(flat)\n",
    "    else:\n",
    "        padded = np.zeros((flat.shape[0], n_features))\n",
    "        padded[:,:n_cols] += flat\n",
    "        out = func(padded)[:,:n_cols]\n",
    "    return reshapor(out) if is_3d else out\n",
    "\n",
    "#transform data\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "    return _apply_padded(min_max_scalor.transform, arr)\n",
    "\n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "    return _apply_padded(min_max_scalor.inverse_transform, arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn - Standard scaler\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_std_scaler(arr):\n",
    "    # Fit a StandardScaler on arr (2D, or 3D which is flattened with reshapor_inv)\n",
    "    # and return the fitted scaler; fit() returns the scaler, not data, so its\n",
    "    # result must not be passed to reshapor.\n",
    "    std_scalor = preprocessing.StandardScaler()\n",
    "    if len(arr.shape) == 3:\n",
    "        std_scalor.fit(reshapor_inv(arr))\n",
    "    else:\n",
    "        std_scalor.fit(arr)\n",
    "    return std_scalor\n",
    "\n",
    "std_scalor = set_std_scaler(train_set)\n",
    "\n",
    "#transform data\n",
    "def std_scaler(arr, std_scalor= std_scalor):\n",
    "    return _apply_padded(std_scalor.transform, arr)\n",
    "\n",
    "#inverse transformation\n",
    "def std_scaler_inv(arr, std_scalor= std_scalor):\n",
    "    return _apply_padded(std_scalor.inverse_transform, arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#reshape the data\n",
    "\n",
    "train_set = reshapor(train_set)\n",
    "test_set = reshapor(test_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Scale data either with MinMax scaler or with Standard scaler\n",
    "#scaler returns the scaled array, scaler_inv maps scaled data back;\n",
    "#scalerfunc selects the scaler: \"std\" or \"minmax\"\n",
    "\n",
    "def scaler(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, 
scalerfunc= \"std\"):\n",
    "    # Scale arr (2D or 3D) with the fitted scaler selected by scalerfunc.\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler(arr, std_scalor= std_scalor)\n",
    "    if scalerfunc == \"minmax\":\n",
    "        return min_max_scaler(arr, min_max_scalor= min_max_scalor)\n",
    "    raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n",
    "\n",
    "def scaler_inv(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    # Undo scaler(): map scaled data (2D or 3D) back to the original coordinates.\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler_inv(arr, std_scalor= std_scalor)\n",
    "    if scalerfunc == \"minmax\":\n",
    "        # the min-max inverse needs the min-max scaler, not std_scalor\n",
    "        return min_max_scaler_inv(arr, min_max_scalor= min_max_scalor)\n",
    "    raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scale the data\n",
    "\n",
    "func = \"minmax\"\n",
    "\n",
    "train_set = scaler(train_set, scalerfunc = func)\n",
    "test_set = scaler(test_set, scalerfunc = func)\n",
    "\n",
    "if func == \"minmax\":\n",
    "    scalor = min_max_scalor\n",
    "elif func == \"std\":\n",
    "    scalor = std_scalor\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    # Shuffle two 3D arrays with the same random permutation of their first axis.\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    # Jointly shuffle (inputt, target) along the first axis and cut them into\n",
    "    # consecutive batches of minibatch_size examples. The last batch holds the\n",
    "    # remainder and is only created when it is non-empty: an empty batch would\n",
    "    # add a zero-loss entry to the epoch average and an extra optimizer step.\n",
    "    # Returns a list of (input_batch, target_batch) tuples.\n",
    "    num_examples = inputt.shape[0]\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    \n",
    "    minibatches = []\n",
    "    for begin in range(0, num_examples, minibatch_size):\n",
    "        end = begin + minibatch_size\n",
    "        minibatches.append((_i[begin:end, :, :], _t[begin:end, :, :]))\n",
    "    \n",
    "    return minibatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create random minibatches of train and test set with input and target array\n",
    "\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
    "#_train, _target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "#print(train[0,:,:], target[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \n",
    "    \n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\", scalor= scalor):\n",
    "        # Build the multi-layer RNN cell and open a tf.Session; the graph is\n",
    "        # unrolled later by set_cost_and_functions().\n",
    "        # time_steps: sequence length fed to self.X and self.Y\n",
    "        # ninputs / num_output: values per time step in input / output\n",
    "        # ncells: int or list of ints, number of units per RNN layer\n",
    "        # cell_type: \"basic_rnn\", \"lstm\" or \"GRU\"\n",
    "        # activation: \"relu\", \"tanh\", \"leaky_relu\" or \"elu\"\n",
    "        # scalor: the fitted scaler stored with the model (pickled by full_save)\n",
    "        # future_steps is only stored; it is not used inside this class.\n",
    "        \n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "        self.loss_list = []\n",
    "        self.scalor = scalor\n",
    "        \n",
    "        #### The input is of shape (num_examples, time_steps, 
ninputs)\n",
    "        #### ninputs is the dimensionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, self.nsteps, ninputs))\n",
    "        \n",
    "        #Check if activation function valid and set activation\n",
    "        activations = {\"relu\": tf.nn.relu,\n",
    "                       \"tanh\": tf.nn.tanh,\n",
    "                       \"leaky_relu\": tf.nn.leaky_relu,\n",
    "                       \"elu\": tf.nn.elu}\n",
    "        if self.__ not in activations:\n",
    "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(self.__))\n",
    "        self.activation = activations[self.__]\n",
    "        \n",
    "        #Check if cell type valid and set cell_type\n",
    "        cell_types = {\"basic_rnn\": tf.contrib.rnn.BasicRNNCell,\n",
    "                      \"lstm\": tf.contrib.rnn.BasicLSTMCell,\n",
    "                      \"GRU\": tf.contrib.rnn.GRUCell}\n",
    "        if self._ not in cell_types:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(self._))\n",
    "        self.cell_type = cell_types[self._]\n",
    "        \n",
    "        #Check Input of ncells (a single int becomes a one-layer list)\n",
    "        if type(self.ncells) == int:\n",
    "            self.ncells = [self.ncells]\n",
    "        \n",
    "        if type(self.ncells) != list or any(type(n) != int for n in self.ncells):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        \n",
    "        #all layers use the chosen activation except the last one, which uses tanh\n",
    "        self.activationlist = [self.activation]*(len(self.ncells)-1) + [tf.nn.tanh]\n",
    "        \n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=units, activation=act)\n",
    "                                                 for units, act in zip(self.ncells, self.activationlist)])\n",
    "        \n",
    "        #### I now define the output: project the cell output to num_output values per time step\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "        \n",
    "        self.sess = tf.Session()\n",
    "    \n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        # Unroll the RNN over self.X, define the MSE cost against self.Y and an\n",
    "        # Adam training op with learning rate LR, create the Saver used by\n",
    "        # save()/load(), and initialize all variables in self.sess.\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the mean squared error between prediction and target\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))\n",
    "        \n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "    \n",
    "    \n",
    "    def save(self, rnn_folder=\"./rnn_model/rnn_basic\"):\n",
    "        # Write the session variables to the checkpoint prefix rnn_folder\n",
    "        # (requires set_cost_and_functions() to have created self.saver).\n",
    "        self.saver.save(self.sess, rnn_folder)\n",
    "    \n",
    "    \n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        # Restore the session variables from the checkpoint prefix filename\n",
    "        # (requires set_cost_and_functions() to have created self.saver).\n",
    "        self.saver.restore(self.sess, filename)\n",
    "    \n",
    "    \n",
    "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
    "        # Train for `epochs` epochs over the list of (input, target) minibatches.\n",
    "        # Epoch numbers continue from earlier calls (len(self.loss_list)).\n",
    "        # Every `checkpoint` epochs the model is saved to a *_checkpoint folder\n",
    "        # if the epoch loss beats the last checkpoint. Training stops early once\n",
    "        # the epoch-to-epoch loss change has been below 1.5e-7 `patience` times.\n",
    "        # At the end, a checkpoint written during this call is restored if it\n",
    "        # beats the final epoch, and the model is saved to the model folder.\n",
    "        if len(minibatches) == 0:\n",
    "            raise ValueError(\"fit() needs at least one minibatch\")\n",
    "        \n",
    "        patience_cnt = 0\n",
    "        start = len(self.loss_list)\n",
    "        epoche_save = start\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n",
    "        \n",
    "        for iep in range(start, start + epochs):\n",
    "            loss = 0\n",
    "            \n",
    "            batches = len(minibatches)\n",
    "            #Here I iterate over the batches\n",
    "            for train, target in minibatches:\n",
    "                #### Here I train the RNNcell\n",
    "                #### The X is the time series, the Y is shifted by 1 time step\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "                \n",
    "                #cost of this batch after the update step\n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "            \n",
    "            #average the loss over the number of batches\n",
    "            loss /= batches\n",
    "            self.loss_list.append(loss)\n",
    "            \n",
    "            #Here I create the checkpoint if the performance is better\n",
    "            if iep > 1 and iep%checkpoint == 0 and self.loss_list[iep] < self.loss_list[epoche_save]:\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "            \n",
    "            #early stopping with patience (counts plateau epochs; the count is never reset)\n",
    "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 1.5/10**7:\n",
    "                patience_cnt += 1\n",
    "                \n",
    "                if patience_cnt + 1 > patience:\n",
    "                    print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
    "                    print(\"Cost: \",loss)\n",
    "                    break\n",
    "            \n",
    "            #Note that the printed loss is multiplied with 10**6 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*10**6, \"e-6\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "        \n",
    "        #Set model back to the last checkpoint if performance was better.\n",
    "        #epoche_save != start means a checkpoint was written during this call;\n",
    "        #otherwise there is nothing to restore (and iep is unset if epochs == 0).\n",
    "        if epoche_save != start and self.loss_list[epoche_save] < self.loss_list[iep]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\")\n",
    "            print(\"State of last checkpoint checkpoint at epoch \", epoche_save, \" restored\")\n",
    "            print(\"Performance at last checkpoint is \" ,(self.loss_list[iep] - self.loss_list[epoche_save])/self.loss_list[iep]*100, \"% better\" )\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "        self.save(folder)\n",
    "        print(\"\\n\")\n",
    "        print(\"Model saved in at: \", folder)\n",
    "    \n",
    "    \n",
    "    def predict(self, x):\n",
    "        # Run the unrolled network on x (fed to self.X) and return self.output.\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#saves the rnn model and all its parameters including the scaler used\n",
    "#optional also saves the minibatches used to train and the test 
set\n",
    "\n",
    "def full_save(rnn, train= True, test= True):\n",
    "    # Save the TF variables of rnn to get_rnn_folder(...) and pickle its\n",
    "    # hyper-parameters, loss history and scaler as <model folder>.pkl.\n",
    "    # train / test: also pickle the notebook globals `minibatches` and\n",
    "    # `test_input` / `test_target` (they must exist when this is called).\n",
    "    folder = get_rnn_folder(rnn.ncells, rnn._, rnn.__)\n",
    "    rnn.save(folder)\n",
    "    pkl_name = _pkl_name(folder)\n",
    "    \n",
    "    pkl_dic = {\"ncells\": rnn.ncells,\n",
    "               \"ninputs\": rnn.ninputs,\n",
    "               \"future_steps\": rnn.future_steps,\n",
    "               \"nsteps\": rnn.nsteps,\n",
    "               \"num_output\": rnn.num_output,\n",
    "               \"cell_type\": rnn._, #cell_type\n",
    "               \"activation\": rnn.__, #Activation\n",
    "               \"loss_list\": rnn.loss_list,\n",
    "               \"scalor\": rnn.scalor}\n",
    "    \n",
    "    if train:\n",
    "        pkl_dic[\"minibatches\"] = minibatches\n",
    "    \n",
    "    if test:\n",
    "        pkl_dic[\"test_input\"] = test_input\n",
    "        pkl_dic[\"test_target\"] = test_target\n",
    "    \n",
    "    #context manager: the file is flushed and closed right away\n",
    "    with open(pkl_name, \"wb\") as f:\n",
    "        pkl.dump(pkl_dic, f)\n",
    "    \n",
    "    print(\"Model saved at: \", folder)\n",
    "    print(\"Remaining data saved as: {}\".format(pkl_name))\n",
    "\n",
    "\n",
    "#name of the .pkl file that belongs to a model folder,\n",
    "#e.g. \"./rnn_model_x/rnn_basic\" -> \"rnn_model_x.pkl\"\n",
    "def _pkl_name(folder):\n",
    "    return str(Path(folder).parent) + \".pkl\"\n",
    "\n",
    "\n",
    "#loads the rnn model with all its parameters including the scaler used\n",
    "#Checks if the pkl data also contains the training or test sets and returns them accordingly\n",
    "def full_load(folder):\n",
    "    # Returns (rnn, data): the restored model and a list with whatever of\n",
    "    # [minibatches], [test_input, test_target] was pickled by full_save.\n",
    "    # NOTE: unpickling can execute arbitrary code; only load files you created.\n",
    "    \n",
    "    #Directory of pkl file\n",
    "    pkl_name = _pkl_name(folder)\n",
    "    \n",
    "    #Check if pkl file exists\n",
    "    if not Path(pkl_name).is_file():\n",
    "        raise ValueError(\"There is no .pkl file with the name: {}\".format(pkl_name))\n",
    "    \n",
    "    with open(pkl_name, \"rb\") as f:\n",
    "        pkl_dic = pkl.load(f)\n",
    "    ncells = pkl_dic[\"ncells\"]\n",
    "    ninputs = pkl_dic[\"ninputs\"]\n",
    "    scalor = pkl_dic[\"scalor\"]\n",
    "    future_steps = pkl_dic[\"future_steps\"]\n",
    "    timesteps = pkl_dic[\"nsteps\"]\n",
    "    num_output = pkl_dic[\"num_output\"]\n",
    "    cell_type = pkl_dic[\"cell_type\"]\n",
    "    activation = pkl_dic[\"activation\"]\n",
    "    \n",
    "    #Check if test or training set in dictionary\n",
    "    batch = \"minibatches\" in pkl_dic\n",
    "    test = \"test_input\" in pkl_dic\n",
    "    if batch:\n",
    "        minibatches = pkl_dic[\"minibatches\"]\n",
    "    if test:\n",
    "        #the stored arrays, not the key strings\n",
    "        test_input = pkl_dic[\"test_input\"]\n",
    "        test_target = pkl_dic[\"test_target\"]\n",
    "    \n",
    "    #loads and initializes a new model with the exact same properties\n",
    "    tf.reset_default_graph()\n",
    "    rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs,\n",
    "                             ncells=ncells, num_output=num_output, cell_type=cell_type, activation=activation, scalor=scalor)\n",
    "\n",
    "    rnn.set_cost_and_functions()\n",
    "    \n",
    "    rnn.load(folder)\n",
    "    \n",
    "    rnn.loss_list = pkl_dic[\"loss_list\"]\n",
    "    \n",
    "    print(\"Model succesfully loaded\")\n",
    "    \n",
    "    if batch and test:\n",
    "        data = [minibatches, test_input, test_target]\n",
    "        print(\"Minibatches (=training data) and test_input and test_target in data loaded\")\n",
    "    elif batch:\n",
    "        data = [minibatches]\n",
    "        print(\"Minibatches (=training data) loaded in data\")\n",
    "    elif test:\n",
    "        data = [test_input, test_target]\n",
    "        print(\"test_input and test_target loaded in data\")\n",
    "    else:\n",
    "        data = []\n",
    "        print(\"Only Model restored, no trainig or test data found in {}\".format(pkl_name))\n",
    "        print(\"Returned data is empty!\")\n",
    "    \n",
    "    return rnn, data\n",
    "\n",
    "\n",
    "#returns the folder name used by full_save and full_load for a given architecture\n",
    "def get_rnn_folder(ncells, cell_type, activation):\n",
    "    folder = \"./rnn_model_\" + cell_type + \"_\" + activation + \"_\" + str(ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n",
    "    return folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [50, 40, 30, 20, 10]\n",
    "\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code", 
"execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use the retry module or similar alternatives.\n" ] } ], "source": [ "tf.reset_default_graph()\n", "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n", " ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "rnn.set_cost_and_functions()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch number 0\n", "Cost: 376274.61996484315 e-6\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 0 \n", "\n", "\n", "\n", "Model saved in at: ./rnn_model_lstm_leaky_relu_[50,40,30,20,10]c/rnn_basic\n", "Model saved at: ./rnn_model_lstm_leaky_relu_[50,40,30,20,10]c/rnn_basic\n", "Remaining data saved as: rnn_model_lstm_leaky_relu_[50,40,30,20,10]c.pkl\n" ] } ], "source": [ "rnn.fit(minibatches, epochs = 5, print_step=5)\n", "full_save(rnn)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4VGXexvHvLwkhhCoQegklFlCkhN4VEFcFO6AiKKI0EVF23XXd1911fdeyKCpFQHFlUUQUXxZURKRKDdJ7hyglgIB0As/7R4Y1soFJJDNnJnN/rivXzpw5k7k963jntOcx5xwiIiKXEuV1ABERCX0qCxER8UtlISIifqksRETEL5WFiIj4pbIQERG/VBYiIuKXykJERPxSWYiIiF8xXgfILSVLlnSJiYlexxARCStLly7d75xL8LdenimLxMREUlJSvI4hIhJWzGxHdtbTYSgREfFLZSEiIn6pLERExC+VhYiI+KWyEBERv1QWIiLil8pCRET8iviycM7x4ufr2Jp21OsoIiIhK+LLYtv+Y4xfvJP2Q+byxoxNnE4/53UkEZGQE/FlUTWhEF8/1ZJ2NUozePpGfvPGXJZsP+h1LBGRkBLxZQFQqnAcb91XlzEP1efE6bPcM2IBv/90FYdPnPE6mohISFBZZNL6qlJMH9iCns2r8NGSnbQZPJspK3/AOed1NBERT6ksLhAfG8Ozt9Rgcr9mlCkSR78PltHjnymk/njc62giIp5RWVzEteWLMqlPE567tQYLtx6g7eA5jJ67lfSzOgEuIpFHZXEJMdFR9GhWhekDW9KkWglemLqO24d9y+rvD3sdTUQkqFQW2VC+WAFGd0tm2P112XvkFB3emscLU9Zy7FS619FERIJCZZFNZsZvrivL1wNb0qVBJUbP20a71+bwzfq9XkcTEQk4lUUOFS2Qj7/dcR0TezUmPjaah99Loe+479h35KTX0UREAkZl8SslJxZnav/mPN3uSqav28uNg2czbtEOzp3TZbYikveoLC5DbEwU/W5I4ssnmnNtuaI8O2k19769gI17f/I6mohIrlJZ5IKqCYX4oGdDXrm7FpvTjnLLG3P5x1cbOHnmrNfRRERyhcoil5gZ9yRXZMbAltxWqxxvfrOZm4fMZf6W/V5HExG5bCqLXFaiUH4Gd6rN2B4NOHvOcd+oRQz6eAU/HjvtdTQRkV9NZREgzZMSmDagBb1bVWPSsu+5cfBsJi1L1ThTIhKWVBYBVCA2mt+1v5op/ZtRuUQ8T360ggffXcyOA8e8jiYikiMqiyC4ukwRJvZqwl871mTZzkO0e20Ow2dt4YzGmRKRMKGyCJLoKKNr40S+HtiS1leV4qUv13Pbm/NYtvNHr6OJiPilsgiyMkXjGNG1HiO71uPQ8TPcOXw+//N/q/nppCZaEpHQpbLwSLuaZfj6qZZ0a5zI+wt30HbwHKat2eN1LBGRLAW0LMysvZltMLPNZvZMFq/3MrNVZrbczOaZWQ3f8kQzO+FbvtzMRgQyp1cK5Y/h+Q41mdSnKVcUjOWxsUt59P0Udh8+4XU0EZFfsEBdymlm0cBGoC2QCiwBujjn1mZap4hz7ojvcQegj3OuvZklAlOcc9dm9/OSk5NdSkpKLv4TBNeZs+d4Z942Xv96IzFRUTzd7kq6Nk4kOsq8jiYieZiZLXXOJftbL5B7Fg2Azc65rc6508B4oGPmFc4XhU9BIGJvQsgXHUWvltX4akBL6lQqxvP/Xsudw+ezbvcR/28WEQmwQJZFeWBXpuepvmW/YGZ9zWwL8DLQP9NLVcxsmZnNNrPmWX2AmT1qZilmlpKWlpab2T1TqUQ87z/cgCGda5N68Di3vjmPv3+xnhOnNc6UiHgnkGWR1fGT/9pzcM4Ndc5VA34H/NG3eDdQyTlXBxgIfGBmRbJ470jnXLJzLjkhISEXo3vLzOhYuzwznmrJXXXLM2L2Ftq9Pps5G/NGIYp
I+AlkWaQCFTM9rwD8cIn1xwO3AzjnTjnnDvgeLwW2AFcGKGfIKhYfy8t3X8+HPRuRLyqKB99dzIDxy9h/9JTX0UQkwgSyLJYASWZWxcxigc7A5MwrmFlSpqe3AJt8yxN8J8gxs6pAErA1gFlDWuNqJfj8ieb0vzGJqat202bwbCak7NI4UyISNAErC+dcOtAPmAasAyY459aY2V98Vz4B9DOzNWa2nIzDTd18y1sAK81sBTAR6OWcOxiorOEgLl80A9teyef9m5NUqhC/nbiSLqMWsjXtqNfRRCQCBOzS2WAL90tnc+LcOcdHKbt48fN1nEo/R7/W1enVshqxMbrHUkRyJhQunZUAiYoyujSoxIynWtKuRmkGT9/Ib96Yy5LtEb3zJSIBpLIIY6UKx/HWfXUZ070+J06f5Z4RC/j9p6s4fELjTIlI7lJZ5AGtry7F9IEt6Nm8Ch8t2UmbwbOZsvIHnQAXkVyjssgj4mNjePaWGkzu14wyReLo98EyevwzhdQfj3sdTUTyAJVFHnNt+aJM6tOEP95yDQu3HqDt4DmMnruVdE20JCKXQWWRB8VER/FI86p89WQLGlcrwQtT13H7sG9Z/f1hr6OJSJhSWeRhFa6I551uyQy9ry57j5yiw1vzeGHKWo6dSvc6moiEGZVFHmdm3FKrLF8PbEnnBpUYPW8b7V6bwzfr93odTUTCiMoiQhQtkI8X77iOj3s1Jj42moffS6HvuO/Yd+Sk19FEJAyoLCJM/cTiTO3fnKfaXsn0dXu5cfBsxi3awblzusxWRC5OZRGBYmOiePzGJL58ojk1yxXh2UmrufftBWzc+5PX0UQkRKksIljVhEJ82LMRr9xdi81pR7nljbn846sNnDyjiZZE5JdUFhHOzLgnuSIzBrbk1lrlePObzdw8ZC7zt+z3OpqIhBCVhQBQolB+XutUm7E9GnD2nOO+UYsY9PEKfjx22utoIhICVBbyC82TEpg2oAW9W1Xj02Xfc+Pg2UxalqpxpkQinMpC/kuB2Gh+1/5qpjzejErF43nyoxU8+O5idhw45nU0EfGIykIu6pqyRfikdxP+3KEmy3Yeot1rcxg+awtnNM6USMRRWcglRUcZ3ZokMn1gC1pdlcBLX67ntjfnsW73Ea+jiUgQqSwkW8oWLcDbXZN5u2s9Dh47zd3D5zN7Y5rXsUQkSFQWkiM31SzD5H7NqFSiIA+/t4Txi3d6HUlEgkBlITlWpmgcH/dqTLPqJXnm01W8Mm29rpYSyeNUFvKrFMofw+huyXRpUJGhM7cw4KPlnErXnd8ieVWM1wEkfOWLjuLFO66jYvF4Xv5yA7sPn2Rk13oUi4/1OpqI5DLtWchlMTP6tKrOkM61Wb7zEHcNn8+ug5r3WySvUVlIruhYuzxjezRg/9HT3DHsW1bsOuR1JBHJRSoLyTUNq5bgk95NKBAbTaeRC/hqzR6vI4lILlFZSK6qXqoQn/ZuylVlivDYv5by3rfbvI4kIrlAZSG5LqFwfsb3bESba0rz/L/X8tcpazUTn0iYU1lIQBSIjWbEA/Xo3iSRd+Zto8+47zSpkkgYU1lIwERHGc93qMmfbq3BtLV76DJqIQeOnvI6loj8CioLCbiHm1Vh+P31WPvDEe4cPp+taUe9jiQiOaSykKBof20Zxj/aiKMn07lz+HxSth/0OpKI5IDKQoKmTqUr+LRPE4rHx3Lf6EVMWfmD15FEJJtUFhJUlUsU5JPeTbi+QlH6fbCMEbO3aBBCkTCgspCgu6JgLGN7NOTWWmX5+xfr+eNnq0nX7HsiIU0DCYon4vJF80bnOlS4Ip4Rs7ew+/BJ3uxSh4L59a+kSCjSnoV4JirKeObmq/nbHdcya8M+Oo1cwL4jJ72OJSJZCGhZmFl7M9tgZpvN7JksXu9lZqvMbLmZzTOzGple+73vfRvM7KZA5hRv3d+wMu90q8/WtGPcMWw+G/f+5HUkEblAwMrCzKKBocDNQA2gS+Yy8Pn
AOXedc6428DIw2PfeGkBnoCbQHhjm+32SR7W+uhQTHmvMmbPnuGv4fOZv3u91JBHJJJB7Fg2Azc65rc6508B4oGPmFZxzRzI9LQicvyymIzDeOXfKObcN2Oz7fZKHXVu+KJP6NqVs0Ti6jVnMJ0tTvY4kIj6BLIvywK5Mz1N9y37BzPqa2RYy9iz65+S9kveUL1aAib2b0KBKcZ76eAWvf71Rl9aKhIBAloVlsey/vvXOuaHOuWrA74A/5uS9ZvaomaWYWUpaWtplhZXQUSQuH2O6N+CuuhV4/etNDJq4ktPpurRWxEuBLItUoGKm5xWAS92yOx64PSfvdc6NdM4lO+eSExISLjOuhJLYmChevacWA9okMXFpKg+9t5gjJ894HUskYgWyLJYASWZWxcxiyThhPTnzCmaWlOnpLcAm3+PJQGczy29mVYAkYHEAs0oIMjMGtLmSV++5nkVbD3L38Pl8f+iE17FEIlLAysI5lw70A6YB64AJzrk1ZvYXM+vgW62fma0xs+XAQKCb771rgAnAWuBLoK9zTpMhRKi761Xgnw83YPfhk9wx9FtWf3/Y60giEcfyysnD5ORkl5KS4nUMCaCNe3/ioTFL+PH4aYbeV5fWV5fyOpJI2DOzpc65ZH/r6Q5uCRtXli7MpD5NqJpQkEfeT2Hcoh1eRxKJGCoLCSulisTx0aONaZFUkmcnrebvX6zX/N4iQaCykLBTMH8Mox5M5v6GlRgxewtPfLRc83uLBJiG+JSwFBMdxQu3X0vF4vH8/Yv17Dl8gpFdk7miYKzX0UTyJO1ZSNgyM3q1rMZb99VhReph7ho+n50HjnsdSyRPUllI2Lu1VjnGPdKQg8dPc8ewb1m280evI4nkOSoLyRPqJxbn095NKJg/hi6jFvLl6j1eRxLJU1QWkmdUTSjEpD5NuKZsEXqPW8o787Z5HUkkz1BZSJ5SolB+PuzZiJtqlOGvU9by/OQ1nNWltSKXTWUheU5cvmiG3l+XHs2q8N787fT+11JOnNaltSKXQ2UheVJ0lPHcrTV4/rYaTF+3l86jFrL/6CmvY4mELZWF5Gndm1bh7QfqsWHPEe4Y9i1b0o56HUkkLKksJM9rV7MM4x9tzInTZ7lz2HwWbzvodSSRsKOykIhQu2IxJvVpSolCsTwwehGTV1xqHi4RuVC2ysLMxmZnmUgoq1g8nk97N6F2pWL0/3AZw2Zt1vzeItmU3T2LmpmfmFk0UC/344gEVrH4WMb2aEDH2uV4+csN/GHSatLPan5vEX8uOZCgmf0e+ANQwMyOnF8MnAZGBjibSEDkj4nmtXtrU+GKAgyduYUfDp1g6P11KZRf42qKXMwl9yycc//rnCsMvOKcK+L7KeycK+Gc+32QMorkuqgoY9BNV/O/d17HvM37uXfEAvYeOel1LJGQld3DUFPMrCCAmT1gZoPNrHIAc4kERZcGlXinWzI7Dhzj9qHfsn7PEf9vEolA2S2L4cBxM7se+C2wA3g/YKlEgqjVVaWY0Ksx55zjnuELmLdpv9eRREJOdssi3WVcNtIRGOKcGwIUDlwskeCqWa4ok/o0pfwVBeg+ZjETUnZ5HUkkpGS3LH7ynezuCkz1XQ2VL3CxRIKvXLECfNyrMY2rleC3E1cy+KsNurRWxCe7ZdEJOAU87JzbA5QHXglYKhGPFI7Lx7vd63NvcgXe+GYzT01Ywel0XVorkq2y8BXEOKComd0KnHTO6ZyF5En5oqN46a5aPNX2Sj5d9j3d3l3M4RNnvI4l4qns3sF9L7AYuAe4F1hkZncHMpiIl8yMx29M4rVO15Oy4yB3D59P6o+a31siV3YPQz0L1HfOdXPOPQg0AJ4LXCyR0HBHnQq8/3BD9h45yR3D5rMq9bDXkUQ8kd2yiHLO7cv0/EAO3isS1hpXK8EnvZsQGx3FvW8vYMa6vV5HEgm67P4H/0szm2Zm3c2sOzAV+DxwsURCS1Lpwkzq24TqpQrR8/0Uxi7c4XU
kkaC6ZFmYWXUza+qcGwS8DdQCrgcWoLGhJMKUKhzHR4814oarS/HcZ6t58fN1nNP83hIh/O1ZvA78BOCc+9Q5N9A59yQZexWvBzqcSKiJj43h7a7JPNi4MiPnbOXxD5dx8ozm95a8z98wm4nOuZUXLnTOpZhZYkASiYS46Cjjzx1qUvGKeP72+Tr2HDnJqAeTKV4w1utoIgHjb88i7hKvFcjNICLhxMzo2aIqw+6vy6rvD3PX8Pls33/M61giAeOvLJaYWc8LF5pZD2BpYCKJhI/fXFeWD3s25NDx09w5fD5Ld/zodSSRgLBLjX1jZqWBSWRMdnS+HJKBWOAO353dISE5OdmlpKR4HUMi1Lb9x3hozGJ2Hz7J651qc/N1Zb2OJJItZrbUOZfsbz1/kx/tdc41Af4MbPf9/Nk51ziUikLEa1VKFuTTPk2pWa4IfT74jtFzt2oQQslTsjWPpHNuJjAzwFlEwlrxgrF80LMRAycs54Wp69h18Dh/uq0m0VHmdTSRy6a7sEVyUVy+aN7qUpdHW1Tlnwt28NjYpRw/ne51LJHLFtCyMLP2ZrbBzDab2TNZvD7QzNaa2Uozm5F5qlYzO2tmy30/kwOZUyQ3RUUZf/jNNfy1Y02+Wb+XziMXsk/ze0uYC1hZ+CZIGgrcDNQAuphZjQtWWwYkO+dqAROBlzO9dsI5V9v30yFQOUUCpWvjREZ2TWbT3qPc8I/ZDJ+1RTfwSdgK5J5FA2Czc26rc+40MJ6MaVn/wzk30zl3ftznhUCFAOYRCbo2NUoztX8zGlUtzktfrqfN4NlMXblbJ78l7ASyLMoDmScyTvUtu5gewBeZnseZWYqZLTSz2wMRUCQYqiYUYnS3+ox7pCGF8sfQ94PvuGfEAlbsOuR1NJFsC2RZZHUJSJZ/TpnZA2Tcv5F5qtZKvmt/7wNeN7NqWbzvUV+hpKSlpeVGZpGAaVq9JFP7N+fvd17H9gPH6Tj0W578aDm7D5/wOpqIX4Esi1SgYqbnFYAfLlzJzNqQMblSB+fcqfPLnXM/+P53KzALqHPhe51zI51zyc655ISEhNxNLxIA0VFG5waVmDWoFX1bV2Pqqt20fnUWg7/awLFTumpKQlcgy2IJkGRmVcwsFugM/OKqJjOrQ8bQ5x0yT65kZleYWX7f45JAU2BtALOKBFWh/DEMuulqvnmqJW1rlOGNbzbT+tVZfJyyS8OeS0gKWFk459KBfsA0YB0wwTm3xsz+Ymbnr256BSgEfHzBJbLXAClmtoKMmwH/7pxTWUieU+GKeN7sUodPejehXLECDJq4ktvemsfCrQe8jibyC5ccGyqcaGwoCXfOOSav+IGXvljPD4dPclPN0vz+5mtILFnQ62iSh+XK2FAiEjxmRsfa5fnm6VYMuukq5m3aT9vXZvO3qWs5fOKM1/EkwqksREJMXL5o+rauzsxBrbizTgVGz9tGq1dm8v6C7aSfPed1PIlQKguREFWqcBwv3V2LqY8355qyRfjT/62h/ZC5zFy/Tzf1SdCpLERCXI1yRRj3SENGPZjM2XOOh95bwoPvLmbDnp+8jiYRRGUhEgbMjLY1SjNtQAueu7UGK3Yd4uYhc3h20ir2Hz3l/xeIXCaVhUgYiY2JokezKswe1JoHGyfy0ZJdtH5lFiNmb+FUugYplMBRWYiEoSsKxvJ8h5pMe7IFDasW5+9fZAxS+PkqDVIogaGyEAlj1XyDFP6rR0MKxsbQZ9x33Pu2BimU3KeyEMkDmiX9PEjhtv0ZgxQO1CCFkotUFiJ5ROZBCvu0qsaU84MUTt+oqV3lsqksRPKYQvlj+G37q5kxsCVtrinNGzM20frVWUxcmqpBCuVXU1mI5FEVi8fz1n11+aR3E8oWLcDTH6+gw9B5LNIghfIrqCxE8rh6la/g095NGNK5NgePnqbTyIX0GruUHQeOeR1NwojKQiQCREX9PEjh0+2uZM6mNNoM1iCFkn0qC5EIEpcvmn43JDHr6Z8
HKWz96izGLtAghXJpKguRCFSqSMYghVMeb8ZVpQvz3PlBCjfs8/9miUgqC5EIVrNcUT7o2ZCRXeuRfvYcD43JGKRw414NUii/pLIQiXBmRruaZfjqyZY8d2sNlu/8kfavZwxSeECDFIqPykJEgP8epHD8kl20emUWb2uQQkFlISIX+M8ghQNa0KBKcf73i/W0HTyHLzRIYURTWYhIlqqXKsQ73TMGKYyPjab3uO/o9PZCVqZqkMJIpLIQkUs6P0jhi3dcx9b9R+nw1rcMnKBBCiONykJE/IqOMu5rWImZT7eid6tqTFmZMUjhaxqkMGKoLEQk2wrH5eN3mQYpHOIbpPATDVKY56ksRCTHfh6ksDFlihbgqY9X0HHotyzedtDraBIgKgsR+dXqVS7OJN8ghQeOnuLetxfQ+19L2XnguNfRJJfFeB1ARMLb+UEK29Uow+i5Wxk+ewsz1u2je9NE+t1QnSJx+byOKLlAexYikisKxEbz+I1JzHy6FR1rl2PU3K20emUWYxfu0CCFeYDKQkRyVekicbxyz/X8u18zrixdiOc+W83NQ+YyS4MUhjWVhYgExLXli/Jhz0a83bUeZ86eo/uYJXR7dzGbNEhhWFJZiEjAmBk3+QYp/OMt17Bs54+0HzKX5z5brUEKw4zKQkQCLjYmikeaV2X2oNZ0bVSZDxbvpNWrsxg5R4MUhguVhYgETeZBCusnFufFzzVIYbhQWYhI0FUvVYh3u9dnbI8GFMjnG6Rw5EJWpR72OppchMpCRDzTPCmBqf2bZQxSmHaU296ax8AJy9lz+KTX0eQClld2/ZKTk11KSorXMUTkV/rp5BmGzdrCO/O2YUD3pon0blmNYvGxXkfL08xsqXMu2e96KgsRCSW7Dh7n9a838emyVArlj6FXy2o81DSR+FgNOBEI2S2LgB6GMrP2ZrbBzDab2TNZvD7QzNaa2Uozm2FmlTO91s3MNvl+ugUyp4iEjorF4/nHvdfz5RMtaFilBK9M20BL353gZ3QnuGcCtmdhZtHARqAtkAosAbo459ZmWqc1sMg5d9zMegOtnHOdzKw4kAIkAw5YCtRzzv14sc/TnoVI3rR0x0Fe+mIDi7cfpHKJeAa2vZLbapUjKsq8jpYnhMKeRQNgs3Nuq3PuNDAe6Jh5BefcTOfc+eEpFwIVfI9vAqY75w76CmI60D6AWUUkRNWrXJyPHmvEmIfqEx8bwxPjl3PLm/OYuWGfLrcNokCWRXlgV6bnqb5lF9MD+OJXvldE8jAzo/VVpZj6eDOGdK7NsVPpPDRmCZ1GLmTpDs2hEQyBLIus9hGz/DPAzB4g45DTKzl5r5k9amYpZpaSlpb2q4OKSHg4Pxz61wNb8tfbr2Xb/mPcNXwBj/wzhQ17NOZUIAWyLFKBipmeVwB+uHAlM2sDPAt0cM6dysl7nXMjnXPJzrnkhISEXAsuIqEtNiaKro0qM3tQKwbddBWLth2g/ZA5DJywnF0HNfFSIATyBHcMGSe4bwS+J+ME933OuTWZ1qkDTATaO+c2ZVpenIyT2nV9i74j4wT3Rfc3dYJbJHIdOn6a4bO38N632znnHPc3rEy/G6pTslB+r6OFvJC4z8LMfgO8DkQD7zrn/mZmfwFSnHOTzexr4Dpgt+8tO51zHXzvfRj4g2/535xzYy71WSoLEdlz+CRDZmxiQsou8vsGL+zZvAqFNVvfRYVEWQSTykJEztuadpR/TN/I1JW7uSI+H31bV+eBRpWJyxftdbSQo7IQkYi3KvUwL09bz9xN+ylXNI4Bba/kzjrliYnWsHjnhcJ9FiIinrquQlHG9mjIB480JKFIHL+duJL2Q+by5eo9ukcjh1QWIpLnNaleks/6NGHEA/VwztHrX0u5fdh85m/Z73W0sKGyEJGIYGa0v7YM0wa04OW7a5F25CT3jVpE13cWsfp7zaPhj85ZiEhEOnnmLP9auIOhMzfz4/Ez3FqrLE+1u4oqJQt6HS2odIJbRCQ
bjpw8w+g5Wxk9bxun0s/RqX5FnrgxidJF4ryOFhQqCxGRHEj76RRDZ25m3KIdRJnxUNMq9G5ZjaLxefseDZWFiMivsOvgcV6bvpFJy7+PiMmXVBYiIpdh/Z4jvDptA1+v20dC4fz0vzGJzvUrki+P3aOh+yxERC7D1WWKMLpbfSb2akxiiXie+2w1bQbPZvKKHzh3Lm/8kZ0TKgsRkUtITizOhMcaM6Z7fQrki6b/h8u49c15zIqwyZdUFiIifpgZra8uxef9mzOkc22Onkqn+5gldB65kKU7Ljrbc56ishARyaZfTL7UsSZb0o5x1/D59Hw/hY178/bkSzrBLSLyKx0/nc6Yb7czYtYWjp5O5846FRjQJomKxeO9jpZtuhpKRCRIfjx2mhGzt/De/O04B/c3qkTf1uEx+ZLKQkQkyHYfPsEbMzYxISWVON/kS4+E+ORLKgsREY9sSTvK4K82MnXVbooXjKVv6+rc37BSSE6+pLIQEfHYytRDvDJtA3M37ad8sQIMaJPEnXUrEB1lXkf7D92UJyLisVoVijG2R0PGPdKQkoViGTRxJe1fn8O0NeE3+ZLKQkQkwJpWL8lnfZsy4oG6nHWOx8Yu5Y5h81mw5YDX0bJNZSEiEgQZky+V5asBLXj5rlrsPXKSLqMW8uC7i8Ni8iWdsxAR8cDJM2cZu2AHQ2dt5pCHky/pBLeISBg4cvIMo+ZsZfTcbZw+G/zJl1QWIiJhJO2nU7z1zSY+WLyT6Cije5PgTL6kshARCUM7Dxznta838tny7ymcP4ZerarxUJMqFIgNzD0aKgsRkTC2bnfG5Esz1u+jlG/ypU4BmHxJ91mIiISxa8oW4Z3u9fm4V2MqFY/nj5+tpq2Hky+pLEREQlj9xOJ83Ksx73ZPJs7DyZdUFiIiIc7MuOHq0nzevzmvd6rNT6fOBH3yJZWFiEiYiIoybq9TnhkDW/GXTJMv9R33XcD3MmIC+ttFRCTXxcZE8WDjRO6qW4Ex327jxJmzmAV2cEKVhYhImCqYP4Z+NyQF5bN0GEpERPxSWYiIiF8qCxER8UtlISIifqksRETEL5WFiIj4pbIQERG/VBYiIuJXnhmi3MzSgB2X8StKAvvbbz14AAAF40lEQVRzKU5uUq6cUa6cUa6cyYu5KjvnEvytlGfK4nKZWUp2xnQPNuXKGeXKGeXKmUjOpcNQIiLil8pCRET8Uln8bKTXAS5CuXJGuXJGuXImYnPpnIWIiPilPQsREfErosrCzNqb2QYz22xmz2Txen4z+8j3+iIzSwyRXN3NLM3Mlvt+HglSrnfNbJ+Zrb7I62Zmb/hyrzSzuiGSq5WZHc60vf4UpFwVzWymma0zszVm9kQW6wR9m2UzV9C3mZnFmdliM1vhy/XnLNYJ+ncym7k8+U76PjvazJaZ2ZQsXgvc9nLORcQPEA1sAaoCscAKoMYF6/QBRvgedwY+CpFc3YG3PNhmLYC6wOqLvP4b4AvAgEbAohDJ1QqY4sH2KgvU9T0uDGzM4v/LoG+zbOYK+jbzbYNCvsf5gEVAowvW8eI7mZ1cnnwnfZ89EPggq/+/Arm9ImnPogGw2Tm31Tl3GhgPdLxgnY7AP32PJwI3WqDnKsxeLk845+YABy+xSkfgfZdhIVDMzMqGQC5POOd2O+e+8z3+CVgHlL9gtaBvs2zmCjrfNjjqe5rP93PhSdSgfyezmcsTZlYBuAUYfZFVAra9IqksygO7Mj1P5b+/MP9ZxzmXDhwGSoRALoC7fIctJppZxQBnyq7sZvdCY99hhC/MrGawP9y3+1+HjL9KM/N0m10iF3iwzXyHVJYD+4DpzrmLbq8gfiezkwu8+U6+DvwWOHeR1wO2vSKpLLJq1wv/WsjOOrktO5/5byDROVcL+Jqf/3LwmhfbKzu+I2MIg+uBN4HPgvnhZlYI+AQY4Jw7cuHLWbwlKNvMTy5Ptplz7qxzrjZQAWhgZtdesIon2ysbuYL+nTSzW4F9zrm
ll1oti2W5sr0iqSxSgcztXwH44WLrmFkMUJTAH+7wm8s5d8A5d8r3dBRQL8CZsis72zTonHNHzh9GcM59DuQzs5LB+Gwzy0fGf5DHOec+zWIVT7aZv1xebjPfZx4CZgHtL3jJi++k31wefSebAh3MbDsZh6tvMLN/XbBOwLZXJJXFEiDJzKqYWSwZJ38mX7DOZKCb7/HdwDfOd6bIy1wXHNPuQMYx51AwGXjQd4VPI+Cwc26316HMrMz547Rm1oCMf88PBOFzDXgHWOecG3yR1YK+zbKTy4ttZmYJZlbM97gA0AZYf8FqQf9OZieXF99J59zvnXMVnHOJZPx34hvn3AMXrBaw7RWTG78kHDjn0s2sHzCNjCuQ3nXOrTGzvwApzrnJZHyhxprZZjLauHOI5OpvZh2AdF+u7oHOBWBmH5JxlUxJM0sF/oeMk30450YAn5Nxdc9m4DjwUIjkuhvobWbpwAmgcxBKHzL+8usKrPId7wb4A1ApUzYvtll2cnmxzcoC/zSzaDLKaYJzborX38ls5vLkO5mVYG0v3cEtIiJ+RdJhKBER+ZVUFiIi4pfKQkRE/FJZiIiIXyoLERHxS2UhkgNmdjbTSKPLLYtRgi/jdyfaRUbSFfFaxNxnIZJLTviGgRCJKNqzEMkFZrbdzF7yzYOw2Myq+5ZXNrMZvgHnZphZJd/y0mY2yTdw3woza+L7VdFmNsoy5lH4yncHsYjnVBYiOVPggsNQnTK9dsQ51wB4i4zRQfE9ft834Nw44A3f8jeA2b6B++oCa3zLk4ChzrmawCHgrgD/84hki+7gFskBMzvqnCuUxfLtwA3Oua2+Qfv2OOdKmNl+oKxz7oxv+W7nXEkzSwMqZBqM7vzw4dOdc0m+578D8jnnXgj8P5nIpWnPQiT3uIs8vtg6WTmV6fFZdF5RQoTKQiT3dMr0vwt8j+fz82Bu9wPzfI9nAL3hPxPtFAlWSJFfQ3+1iORMgUwjtwJ86Zw7f/lsfjNbRMYfYV18y/oD75rZICCNn0eZfQIYaWY9yNiD6A14Pry7yMXonIVILvCds0h2zu33OotIIOgwlIiI+KU9CxER8Ut7FiIi4pfKQkRE/FJZiIiIXyoLERHxS2UhIiJ+qSxERMSv/we7vQ+aZEsu8gAAAABJRU5ErkJggg==\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#Plot the loss\n", "def plot_loss_list(loss_list= rnn.loss_list):\n", " plt.plot(rnn.loss_list)\n", " plt.xlabel(\"Epoch\")\n", " plt.ylabel(\"Cost\")\n", " plt.show()\n", "\n", "plot_loss_list()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "#folder = get_rnn_folder(ncells = ncells, cell_type = \"lstm\", activation = \"leaky_relu\")\n", "#rnn, data = full_load(folder)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def rnn_test(rnn, test_input= test_input, test_target= test_target):\n", " \n", " #Here I 
predict based on my test set\n", " test_pred = rnn.predict(test_input)\n", " \n", " #Here i subtract a prediction (random particle) from the target to get an idea of the predictions\n", " #scaler_inv(test_input, scalerfunc = func)[0,:,:]\n", " diff = scaler_inv(test_pred, scalerfunc = func)-scaler_inv(test_target, scalerfunc = func )\n", " print(diff[random.randint(0,test_pred.shape[0]),:,:])\n", " \n", " #Here I evaluate my model on the test set based on mean_squared_error\n", " loss = rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})\n", " print(\"Loss on test set:\", loss)\n", " \n", " return test_pred, loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-12.40089102 11.3887205 14.85893158]\n", " [ -2.58961379 15.07054156 13.61432775]\n", " [ -0.97645625 9.64583194 6.30191887]\n", " [ -0.18267044 -0.89442267 -5.31442916]\n", " [ 1.98853966 9.28189136 -5.76572485]\n", " [ -0.40098165 2.10118956 -3.39426367]\n", " [ 2.50788286 -4.49019351 2.29174324]]\n", "Loss on test set: 0.14005564\n" ] } ], "source": [ "test_pred, test_loss = rnn_test(rnn=rnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }