{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      " from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "from sklearn import preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import data as array\n",
    "# 8 hits with x,y,z\n",
    "# NOTE: read_pickle can execute code embedded in the file; only load trusted data.\n",
    "\n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')\n",
    "#print(testset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Check testset with arbitrary particle\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "#print(tset.shape)\n",
    "#for i in range(8):\n",
    "    #print(tset[1,3*i:(3*i+3)])\n",
    "#print(tset[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    \"\"\"Convert a flat (num_examples, 3*timesteps) array into (num_examples, timesteps, 3).\n",
    "\n",
    "    Columns 3*t, 3*t+1, 3*t+2 become time step t; trailing columns that do not\n",
    "    form a complete triple are dropped. Always returns a new float64 array.\n",
    "    \"\"\"\n",
    "    arr_orig = np.asarray(arr_orig)\n",
    "    timesteps = arr_orig.shape[1] // 3\n",
    "    number_examples = arr_orig.shape[0]\n",
    "    # np.array(...) copies, so the result never aliases the input\n",
    "    return np.array(arr_orig[:, :3*timesteps], dtype=np.float64).reshape(number_examples, timesteps, 3)\n",
    "\n",
    "def reshapor_inv(array_shaped):\n",
    "    \"\"\"Inverse of reshapor: (num_examples, timesteps, 3) -> (num_examples, 3*timesteps).\n",
    "\n",
    "    Time step t ends up in columns 3*t, 3*t+1, 3*t+2. Always returns a new float64 array.\n",
    "    \"\"\"\n",
    "    array_shaped = np.asarray(array_shaped)\n",
    "    timesteps = array_shaped.shape[1]\n",
    "    num_examples = array_shaped.shape[0]\n",
    "    return np.array(array_shaped, dtype=np.float64).reshape(num_examples, timesteps*3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    \"\"\"Shuffle the rows of `dataset` and split them into a training and a test set.\n",
    "\n",
    "    The first int(num_examples * train_to_total_ratio) shuffled rows form the training\n",
    "    set, the remaining rows the test set. Both are returned as new float64 arrays.\n",
    "    \"\"\"\n",
    "    #shuffle the dataset\n",
    "    num_examples = dataset.shape[0]\n",
    "    p = np.random.permutation(num_examples)\n",
    "    dataset = dataset[p,:]\n",
    "    \n",
    "    #evaluate size of training set (np.int was removed in NumPy 1.24, use the builtin int)\n",
    "    train_set_size = int(num_examples*train_to_total_ratio)\n",
    "    \n",
    "    #split by slicing; np.array copies and converts to float64\n",
    "    train_set = np.array(dataset[:train_set_size, :], dtype=np.float64)\n",
    "    test_set = np.array(dataset[train_set_size:, :], dtype=np.float64)\n",
    "    \n",
    "    return train_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "\n",
    "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "#print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn\n",
    "\n",
    "def _apply_scaler_fn(arr, scaler_fn, n_features):\n",
    "    \"\"\"Apply a fitted scaler method (e.g. scaler.transform) to `arr`, keeping the input's layout.\n",
    "\n",
    "    `arr` is either flat (num_examples, width) or shaped (num_examples, timesteps, 3).\n",
    "    If width (3*timesteps for shaped input) is smaller than the `n_features` the scaler\n",
    "    was fitted on, the data is zero-padded on the right, transformed, and only the first\n",
    "    `width` columns are kept. This works because the scalers act column by column.\n",
    "    \"\"\"\n",
    "    shaped = len(arr.shape) == 3\n",
    "    flat = reshapor_inv(arr) if shaped else arr\n",
    "    width = flat.shape[1]\n",
    "    if width == n_features:\n",
    "        out = scaler_fn(flat)\n",
    "    else:\n",
    "        padded = np.zeros((flat.shape[0], n_features))\n",
    "        padded[:, :width] = flat\n",
    "        out = scaler_fn(padded)[:, :width]\n",
    "    return reshapor(out) if shaped else out\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_min_max_scaler(arr, feature_range= (-1,1)):\n",
    "    \"\"\"Fit and return a MinMaxScaler on `arr` (flat or shaped (num_examples, timesteps, 3)).\"\"\"\n",
    "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
    "    if len(arr.shape) == 3:\n",
    "        min_max_scalor.fit(reshapor_inv(arr))\n",
    "    else:\n",
    "        min_max_scalor.fit(arr)\n",
    "    return min_max_scalor\n",
    "\n",
    "min_max_scalor = set_min_max_scaler(train_set)\n",
    "\n",
    "\n",
    "#transform data\n",
    "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n",
    "    \"\"\"Scale `arr` with the fitted MinMaxScaler (flat or shaped input, possibly fewer columns).\"\"\"\n",
    "    return _apply_scaler_fn(arr, min_max_scalor.transform, len(min_max_scalor.scale_))\n",
    "    \n",
    "#inverse transformation\n",
    "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n",
    "    \"\"\"Undo min_max_scaler (flat or shaped input, possibly fewer columns).\"\"\"\n",
    "    return _apply_scaler_fn(arr, min_max_scalor.inverse_transform, len(min_max_scalor.scale_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize the data advanced version with scikit learn - Standard scaler\n",
    "\n",
    "#set the transformation based on training set\n",
    "def set_std_scaler(arr):\n",
    "    \"\"\"Fit and return a StandardScaler on `arr` (flat or shaped (num_examples, timesteps, 3)).\"\"\"\n",
    "    std_scalor = preprocessing.StandardScaler()\n",
    "    if len(arr.shape) == 3:\n",
    "        # fit() returns the scaler itself, so its result must not be passed to reshapor\n",
    "        std_scalor.fit(reshapor_inv(arr))\n",
    "    else:\n",
    "        std_scalor.fit(arr)\n",
    "    return std_scalor\n",
    "\n",
    "std_scalor = set_std_scaler(train_set)\n",
    "\n",
    "#transform data\n",
    "def std_scaler(arr, std_scalor= std_scalor):\n",
    "    \"\"\"Standardize `arr` with the fitted StandardScaler (flat or shaped input, possibly fewer columns).\"\"\"\n",
    "    return _apply_scaler_fn(arr, std_scalor.transform, len(std_scalor.scale_))\n",
    "    \n",
    "#inverse transformation\n",
    "def std_scaler_inv(arr, std_scalor= std_scalor):\n",
    "    \"\"\"Undo std_scaler (flat or shaped input, possibly fewer columns).\"\"\"\n",
    "    return _apply_scaler_fn(arr, std_scalor.inverse_transform, len(std_scalor.scale_))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "#reshape the data\n",
    "\n",
    "train_set = reshapor(train_set)\n",
    "test_set = reshapor(test_set)\n",
    "\n",
    "#print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Scale data either with MinMax scaler or with Standard scaler\n",
    "#scalerfunc selects the scaler: \"std\" or \"minmax\"\n",
    "\n",
    "def scaler(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \"\"\"Scale `arr` with the Standard scaler (\\\"std\\\") or the MinMax scaler (\\\"minmax\\\").\"\"\"\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler(arr, std_scalor= std_scalor)\n",
    "    \n",
    "    elif scalerfunc == \"minmax\":\n",
    "        return min_max_scaler(arr, min_max_scalor= min_max_scalor)\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n",
    "\n",
    "def scaler_inv(arr, std_scalor= std_scalor, min_max_scalor= min_max_scalor, scalerfunc= \"std\"):\n",
    "    \"\"\"Invert `scaler`; use the same scalerfunc that was used for scaling.\"\"\"\n",
    "    if scalerfunc == \"std\":\n",
    "        return std_scaler_inv(arr, std_scalor= std_scalor)\n",
    "    \n",
    "    elif scalerfunc == \"minmax\":\n",
    "        # the MinMax inverse needs the MinMax scaler, not the Standard scaler\n",
    "        return min_max_scaler_inv(arr, min_max_scalor= min_max_scalor)\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Uknown scaler chosen: {}\".format(scalerfunc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.02109399 0.0394468 -0.01875739]\n",
      " [-0.0158357 0.02916325 -0.02021501]\n",
      " [-0.00411211 0.01346626 -0.01817778]\n",
      " [-0.00314466 0.01169437 -0.00971874]\n",
      " [ 0.00827457 -0.00905463 -0.00903793]\n",
      " [ 0.00906477 -0.01100179 -0.00610165]\n",
      " [ 0.01623521 -0.02745446 0.00036546]\n",
      " [ 0.01879028 -0.03098714 -0.0009012 ]]\n"
     ]
    }
   ],
   "source": [
    "#scale the data\n",
    "\n",
    "func = \"minmax\"\n",
    "\n",
    "train_set = scaler(train_set, scalerfunc = func)\n",
    "test_set = scaler(test_set, scalerfunc = func)\n",
    "\n",
    "print(train_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    \"\"\"Shuffle `a` and `b` along axis 0 with the same permutation and return the copies.\"\"\"\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    \"\"\"Shuffle (inputt, target) jointly and cut them into minibatches of `minibatch_size` examples.\n",
    "\n",
    "    Returns a list of (input_batch, target_batch) tuples. The last batch holds the\n",
    "    remaining examples and is only added if there are any, so no empty batch is fed\n",
    "    to the network or counted in the epoch's average loss.\n",
    "    \"\"\"\n",
    "    num_examples = inputt.shape[0]\n",
    "    \n",
    "    #Number of complete batches\n",
    "    number_of_batches = num_examples // minibatch_size\n",
    "    \n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    \n",
    "    minibatches = [(_i[minibatch_size*i:minibatch_size*(i+1), :, :],\n",
    "                    _t[minibatch_size*i:minibatch_size*(i+1), :, :])\n",
    "                   for i in range(number_of_batches)]\n",
    "    \n",
    "    #incomplete last batch, only if examples are left over\n",
    "    if number_of_batches*minibatch_size < num_examples:\n",
    "        minibatches.append((_i[number_of_batches*minibatch_size:, :, :], _t[number_of_batches*minibatch_size:, :, :]))\n",
    "    \n",
    "    return minibatches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Create random minibatches of train and test set with input and target array\n",
    "\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
    "#_train, _target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "#print(train[0,:,:], target[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "#minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "#_inputt, _target = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "#print(len(minibatches))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \"\"\"Stacked RNN (TensorFlow 1.x graph mode) that predicts the next position of a track.\n",
    "\n",
    "    X holds the input sequences, shape (num_examples, time_steps, ninputs); Y holds the\n",
    "    targets, shape (num_examples, time_steps, num_output). In this notebook Y is X\n",
    "    shifted by one time step.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\"):\n",
    "        \n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        self._ = cell_type #later used to create folder name\n",
    "        self.__ = activation #later used to create folder name\n",
    "        \n",
    "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
    "        #### ninputs is the dimensionality (number of features) of the time series (here coordinates)\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "        #### The target has to match the network output, which has num_output features per step\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, num_output))\n",
    "        \n",
    "        #Check if activation function valid and set activation\n",
    "        if activation==\"relu\":\n",
    "            self.activation = tf.nn.relu\n",
    "        \n",
    "        elif activation==\"tanh\":\n",
    "            self.activation = tf.nn.tanh\n",
    "        \n",
    "        elif activation==\"leaky_relu\":\n",
    "            self.activation = tf.nn.leaky_relu\n",
    "        \n",
    "        elif activation==\"elu\":\n",
    "            self.activation = tf.nn.elu\n",
    "        \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(activation))\n",
    "        \n",
    "        #Check if cell type valid and set cell_type\n",
    "        if cell_type==\"basic_rnn\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicRNNCell\n",
    "        \n",
    "        elif cell_type==\"lstm\":\n",
    "            self.cell_type = tf.contrib.rnn.BasicLSTMCell\n",
    "        \n",
    "        elif cell_type==\"GRU\":\n",
    "            self.cell_type = tf.contrib.rnn.GRUCell\n",
    "        \n",
    "        else:\n",
    "            raise ValueError(\"Wrong rnn cell type: {}\".format(cell_type))\n",
    "        \n",
    "        #Check Input of ncells (an int or a list of ints, one per layer)\n",
    "        if (type(self.ncells) == int):\n",
    "            self.ncells = [self.ncells]\n",
    "        \n",
    "        if (type(self.ncells) != list):\n",
    "            raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        \n",
    "        for _ in range(len(self.ncells)):\n",
    "            if type(self.ncells[_]) != int:\n",
    "                raise ValueError(\"Wrong type of Input for ncells\")\n",
    "        \n",
    "        #Hidden layers use the chosen activation, the last layer always uses tanh\n",
    "        self.activationlist = []\n",
    "        for _ in range(len(self.ncells)-1):\n",
    "            self.activationlist.append(self.activation)\n",
    "        self.activationlist.append(tf.nn.tanh)\n",
    "        \n",
    "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n",
    "                                                 for layer in range(len(self.ncells))])\n",
    "        \n",
    "        #### I now define the output (linear projection to num_output features per step)\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "        \n",
    "        self.sess = tf.Session()\n",
    "    \n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        \"\"\"Build the unrolled RNN, the MSE cost and the Adam training op, then initialise all variables.\"\"\"\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the mean_squared_error (distance of predicted point to target)\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))\n",
    "        \n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "    \n",
    "    def save(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Save the session's variables to `filename`.\"\"\"\n",
    "        self.saver.save(self.sess, filename)\n",
    "    \n",
    "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
    "        \"\"\"Restore the session's variables from `filename`.\"\"\"\n",
    "        self.saver.restore(self.sess, filename)\n",
    "    \n",
    "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
    "        \"\"\"Train on `minibatches`, a list of (input, target) array pairs, for up to `epochs` epochs.\n",
    "\n",
    "        The mean training loss of each epoch is appended to self.loss_list. Every\n",
    "        `checkpoint` epochs (after epoch 1) the model is saved if its loss beats the last\n",
    "        saved epoch. Training stops once `patience` epochs in total changed the loss by\n",
    "        less than 2e-7 relative to the previous epoch. Afterwards the last checkpoint\n",
    "        written in this run is restored if it was better than the final epoch.\n",
    "        \"\"\"\n",
    "        self.loss_list = []\n",
    "        patience_cnt = 0\n",
    "        epoche_save = 0\n",
    "        #True once this run has written a checkpoint to `folder`\n",
    "        checkpoint_saved = False\n",
    "        \n",
    "        folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n",
    "        \n",
    "        for iep in range(epochs):\n",
    "            loss = 0\n",
    "            \n",
    "            batches = len(minibatches)\n",
    "            #Here I iterate over the batches\n",
    "            for batch in range(batches):\n",
    "                #### The X is the time series, the Y is shifted by 1 time step\n",
    "                train, target = minibatches[batch]\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "                \n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "            \n",
    "            #Average the loss over the number of batches\n",
    "            loss /= batches\n",
    "            self.loss_list.append(loss)\n",
    "            \n",
    "            #Here I create the checkpoint if the performance is better\n",
    "            if iep > 1 and iep%checkpoint == 0 and self.loss_list[iep] < self.loss_list[epoche_save]:\n",
    "                self.save(folder)\n",
    "                epoche_save = iep\n",
    "                checkpoint_saved = True\n",
    "            \n",
    "            #early stopping with patience (counts all near-flat epochs, not only consecutive ones)\n",
    "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 2/10**7:\n",
    "                patience_cnt += 1\n",
    "            \n",
    "            if patience_cnt + 1 > patience:\n",
    "                print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
    "                print(\"Cost: \",loss)\n",
    "                break\n",
    "            \n",
    "            #Note that the loss here is multiplied with 10**6 for easier reading\n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss*10**6, \"e-6\")\n",
    "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
    "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
    "        \n",
    "        #Set model back to the last checkpoint if performance was better. Only a checkpoint\n",
    "        #written in this run is restored; otherwise `folder` may not exist or may hold a\n",
    "        #stale model from an earlier run.\n",
    "        if checkpoint_saved and self.loss_list[epoche_save] < self.loss_list[iep]:\n",
    "            self.load(folder)\n",
    "            print(\"\\n\", \"Last checkpoint at epoch \", epoche_save, \" loaded\")\n",
    "            print(\"Performance at last checkpoint is \" ,self.loss_list[iep] - self.loss_list[epoche_save], \" better\" )\n",
    "    \n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the network output for input sequences `x`.\"\"\"\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps = 7\n",
    "future_steps = 1\n",
    "\n",
    "ninputs = 3\n",
    "\n",
    "#ncells as int or list of int\n",
    "ncells = [50, 40, 30, 20, 10]\n",
    "\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
"output_type": "stream", "text": [ "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use the retry module or similar alternatives.\n" ] } ], "source": [ "tf.reset_default_graph()\n", "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n", " ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "rnn.set_cost_and_functions()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch number 0\n", "Cost: 10.041199672838395 e-3\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 0 \n", "\n", "Epoch number 5\n", "Cost: 0.14646259134021053 e-3\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 5 \n", "\n", "Epoch number 10\n", "Cost: 0.14038292159811852 e-3\n", "Patience: 5 / 200\n", "Last checkpoint at: Epoch 10 \n", "\n", "Epoch number 15\n", "Cost: 0.13558934176429868 e-3\n", "Patience: 10 / 200\n", "Last checkpoint at: Epoch 15 \n", "\n", "Epoch number 20\n", "Cost: 0.12642440127278182 e-3\n", "Patience: 14 / 200\n", "Last checkpoint at: Epoch 20 \n", "\n", "Epoch number 25\n", "Cost: 0.1116786241912818 e-3\n", "Patience: 16 / 200\n", "Last checkpoint at: Epoch 25 \n", "\n", "Epoch number 30\n", "Cost: 0.10637743763893129 e-3\n", "Patience: 20 / 200\n", "Last checkpoint at: Epoch 30 \n", "\n", "Epoch number 35\n", "Cost: 0.10180761176904544 e-3\n", "Patience: 21 / 200\n", "Last checkpoint at: Epoch 35 \n", "\n", "Epoch number 40\n", "Cost: 0.10329305703325713 e-3\n", "Patience: 25 / 200\n", "Last checkpoint at: Epoch 35 \n", "\n", "Epoch 
number 45\n", "Cost: 0.09893714772299567 e-3\n", "Patience: 26 / 200\n", "Last checkpoint at: Epoch 45 \n", "\n", "Epoch number 50\n", "Cost: 0.09669851916693548 e-3\n", "Patience: 28 / 200\n", "Last checkpoint at: Epoch 50 \n", "\n", "Epoch number 55\n", "Cost: 0.09474931256919901 e-3\n", "Patience: 30 / 200\n", "Last checkpoint at: Epoch 55 \n", "\n", "Epoch number 60\n", "Cost: 0.09272654031210163 e-3\n", "Patience: 33 / 200\n", "Last checkpoint at: Epoch 60 \n", "\n", "Epoch number 65\n", "Cost: 0.09420149952812279 e-3\n", "Patience: 35 / 200\n", "Last checkpoint at: Epoch 60 \n", "\n", "Epoch number 70\n", "Cost: 0.09541216964630331 e-3\n", "Patience: 36 / 200\n", "Last checkpoint at: Epoch 60 \n", "\n", "Epoch number 75\n", "Cost: 0.09047800716522962 e-3\n", "Patience: 39 / 200\n", "Last checkpoint at: Epoch 75 \n", "\n", "Epoch number 80\n", "Cost: 0.09089725666699257 e-3\n", "Patience: 39 / 200\n", "Last checkpoint at: Epoch 75 \n", "\n", "Epoch number 85\n", "Cost: 0.08590354093726962 e-3\n", "Patience: 40 / 200\n", "Last checkpoint at: Epoch 85 \n", "\n", "Epoch number 90\n", "Cost: 0.08550771444595041 e-3\n", "Patience: 41 / 200\n", "Last checkpoint at: Epoch 90 \n", "\n", "Epoch number 95\n", "Cost: 0.08262849370750816 e-3\n", "Patience: 42 / 200\n", "Last checkpoint at: Epoch 95 \n", "\n", "Epoch number 100\n", "Cost: 0.08081882078066825 e-3\n", "Patience: 45 / 200\n", "Last checkpoint at: Epoch 100 \n", "\n", "Epoch number 105\n", "Cost: 0.08332692371542624 e-3\n", "Patience: 48 / 200\n", "Last checkpoint at: Epoch 100 \n", "\n", "Epoch number 110\n", "Cost: 0.0850605532871262 e-3\n", "Patience: 50 / 200\n", "Last checkpoint at: Epoch 100 \n", "\n", "Epoch number 115\n", "Cost: 0.08140491588571248 e-3\n", "Patience: 50 / 200\n", "Last checkpoint at: Epoch 100 \n", "\n", "Epoch number 120\n", "Cost: 0.0823190781916987 e-3\n", "Patience: 52 / 200\n", "Last checkpoint at: Epoch 100 \n", "\n", "Epoch number 125\n", "Cost: 0.0766505290038309 e-3\n", 
"Patience: 55 / 200\n", "Last checkpoint at: Epoch 125 \n", "\n", "Epoch number 130\n", "Cost: 0.07502320984210027 e-3\n", "Patience: 56 / 200\n", "Last checkpoint at: Epoch 130 \n", "\n", "Epoch number 135\n", "Cost: 0.0758755330102855 e-3\n", "Patience: 57 / 200\n", "Last checkpoint at: Epoch 130 \n", "\n", "Epoch number 140\n", "Cost: 0.0731801113207884 e-3\n", "Patience: 58 / 200\n", "Last checkpoint at: Epoch 140 \n", "\n", "Epoch number 145\n", "Cost: 0.0745931863499944 e-3\n", "Patience: 60 / 200\n", "Last checkpoint at: Epoch 140 \n", "\n", "Epoch number 150\n", "Cost: 0.05597170093096793 e-3\n", "Patience: 60 / 200\n", "Last checkpoint at: Epoch 150 \n", "\n", "Epoch number 155\n", "Cost: 0.0448569248584 e-3\n", "Patience: 61 / 200\n", "Last checkpoint at: Epoch 155 \n", "\n", "Epoch number 160\n", "Cost: 0.0377340710404864 e-3\n", "Patience: 63 / 200\n", "Last checkpoint at: Epoch 160 \n", "\n", "Epoch number 165\n", "Cost: 0.03712705128759324 e-3\n", "Patience: 64 / 200\n", "Last checkpoint at: Epoch 165 \n", "\n", "Epoch number 170\n", "Cost: 0.037240219558527236 e-3\n", "Patience: 67 / 200\n", "Last checkpoint at: Epoch 165 \n", "\n", "Epoch number 175\n", "Cost: 0.041023939860330774 e-3\n", "Patience: 67 / 200\n", "Last checkpoint at: Epoch 165 \n", "\n", "Epoch number 180\n", "Cost: 0.03179026030108056 e-3\n", "Patience: 69 / 200\n", "Last checkpoint at: Epoch 180 \n", "\n", "Epoch number 185\n", "Cost: 0.037844479401370486 e-3\n", "Patience: 71 / 200\n", "Last checkpoint at: Epoch 180 \n", "\n", "Epoch number 190\n", "Cost: 0.02333719505181665 e-3\n", "Patience: 72 / 200\n", "Last checkpoint at: Epoch 190 \n", "\n", "Epoch number 195\n", "Cost: 0.02318771433412157 e-3\n", "Patience: 77 / 200\n", "Last checkpoint at: Epoch 195 \n", "\n", "Epoch number 200\n", "Cost: 0.025808127712151234 e-3\n", "Patience: 79 / 200\n", "Last checkpoint at: Epoch 195 \n", "\n", "Epoch number 205\n", "Cost: 0.021487966265301518 e-3\n", "Patience: 82 / 200\n", "Last 
checkpoint at: Epoch 205 \n", "\n", "Epoch number 210\n", "Cost: 0.020788879447401144 e-3\n", "Patience: 85 / 200\n", "Last checkpoint at: Epoch 210 \n", "\n", "Epoch number 215\n", "Cost: 0.02056433168810203 e-3\n", "Patience: 85 / 200\n", "Last checkpoint at: Epoch 215 \n", "\n", "Epoch number 220\n", "Cost: 0.016506806942027438 e-3\n", "Patience: 89 / 200\n", "Last checkpoint at: Epoch 220 \n", "\n", "Epoch number 225\n", "Cost: 0.020985714496767265 e-3\n", "Patience: 91 / 200\n", "Last checkpoint at: Epoch 220 \n", "\n", "Epoch number 230\n", "Cost: 0.011625469693520225 e-3\n", "Patience: 94 / 200\n", "Last checkpoint at: Epoch 230 \n", "\n", "Epoch number 235\n", "Cost: 0.013143771576188614 e-3\n", "Patience: 98 / 200\n", "Last checkpoint at: Epoch 230 \n", "\n", "Epoch number 240\n", "Cost: 0.017444268354317522 e-3\n", "Patience: 100 / 200\n", "Last checkpoint at: Epoch 230 \n", "\n", "Epoch number 245\n", "Cost: 0.013935790078942367 e-3\n", "Patience: 101 / 200\n", "Last checkpoint at: Epoch 230 \n", "\n", "Epoch number 250\n", "Cost: 0.01056458899875771 e-3\n", "Patience: 103 / 200\n", "Last checkpoint at: Epoch 250 \n", "\n", "Epoch number 255\n", "Cost: 0.013950063650090088 e-3\n", "Patience: 106 / 200\n", "Last checkpoint at: Epoch 250 \n", "\n", "Epoch number 260\n", "Cost: 0.015239623812800694 e-3\n", "Patience: 109 / 200\n", "Last checkpoint at: Epoch 250 \n", "\n", "Epoch number 265\n", "Cost: 0.014050647958820845 e-3\n", "Patience: 112 / 200\n", "Last checkpoint at: Epoch 250 \n", "\n", "Epoch number 270\n", "Cost: 0.009441311336326799 e-3\n", "Patience: 112 / 200\n", "Last checkpoint at: Epoch 270 \n", "\n", "Epoch number 275\n", "Cost: 0.00812686008391617 e-3\n", "Patience: 116 / 200\n", "Last checkpoint at: Epoch 275 \n", "\n", "Epoch number 280\n", "Cost: 0.009064912048531968 e-3\n", "Patience: 118 / 200\n", "Last checkpoint at: Epoch 275 \n", "\n", "Epoch number 285\n", "Cost: 0.007350245905786808 e-3\n", "Patience: 119 / 200\n", "Last 
checkpoint at: Epoch 285 \n", "\n", "Epoch number 290\n", "Cost: 0.009190695427025004 e-3\n", "Patience: 123 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 295\n", "Cost: 0.009242598896386706 e-3\n", "Patience: 126 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 300\n", "Cost: 0.009243554339921871 e-3\n", "Patience: 131 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 305\n", "Cost: 0.008543941756680069 e-3\n", "Patience: 134 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 310\n", "Cost: 0.008661668753700995 e-3\n", "Patience: 137 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 315\n", "Cost: 0.008509848796282003 e-3\n", "Patience: 142 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 320\n", "Cost: 0.009688999833745953 e-3\n", "Patience: 145 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 325\n", "Cost: 0.010096690673774302 e-3\n", "Patience: 148 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 330\n", "Cost: 0.008155997478597589 e-3\n", "Patience: 152 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 335\n", "Cost: 0.012822152828138837 e-3\n", "Patience: 156 / 200\n", "Last checkpoint at: Epoch 285 \n", "\n", "Epoch number 340\n", "Cost: 0.00638995292552244 e-3\n", "Patience: 159 / 200\n", "Last checkpoint at: Epoch 340 \n", "\n", "Epoch number 345\n", "Cost: 0.0066921474113924165 e-3\n", "Patience: 164 / 200\n", "Last checkpoint at: Epoch 340 \n", "\n", "Epoch number 350\n", "Cost: 0.006151222709028862 e-3\n", "Patience: 169 / 200\n", "Last checkpoint at: Epoch 350 \n", "\n", "Epoch number 355\n", "Cost: 0.006081407573606641 e-3\n", "Patience: 170 / 200\n", "Last checkpoint at: Epoch 355 \n", "\n", "Epoch number 360\n", "Cost: 0.007673800716494293 e-3\n", "Patience: 175 / 200\n", "Last checkpoint at: Epoch 355 \n", "\n", "Epoch number 365\n", "Cost: 0.0072596388893911585 e-3\n", "Patience: 180 / 200\n", 
"Last checkpoint at: Epoch 355 \n", "\n", "Epoch number 370\n", "Cost: 0.006717292966099427 e-3\n", "Patience: 184 / 200\n", "Last checkpoint at: Epoch 355 \n", "\n", "Epoch number 375\n", "Cost: 0.006316999443175093 e-3\n", "Patience: 189 / 200\n", "Last checkpoint at: Epoch 355 \n", "\n", "Epoch number 380\n", "Cost: 0.006750347554461382 e-3\n", "Patience: 193 / 200\n", "Last checkpoint at: Epoch 355 \n", "\n", "Epoch number 385\n", "Cost: 0.006520240363665544 e-3\n", "Patience: 198 / 200\n", "Last checkpoint at: Epoch 355 \n", "\n", "\n", " Early stopping at epoch 387 , difference: 9.317458766708246e-07\n", "Cost: 5.49251195054162e-06\n" ] } ], "source": [ "rnn.fit(minibatches, epochs = 5000, print_step=5)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAEKCAYAAAA4t9PUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHq5JREFUeJzt3X2QXfV93/H35967u1o9IIxQXJAwkoPSRHj8uFH9VE9rUls4nshp5SDGTpiUKa0LtZ00cVA9JS5TpqVpjc0YO0MMMSaOBZXt8Y5LjB0LJ+PWkVhsbCOI7C3gsgabxSAhkPbh3vvtH+e3q6vLvecsu3t2L+LzmtnZc3/3nHO/9+zu/ezvPPyOIgIzM7P5qix3AWZm9sLmIDEzswVxkJiZ2YI4SMzMbEEcJGZmtiAOEjMzWxAHiZmZLYiDxMzMFsRBYmZmC1Jb7gKWwplnnhmbNm1a7jLMzF4w7rnnniciYv1c5n1RBMmmTZsYGRlZ7jLMzF4wJP14rvN615aZmS2Ig8TMzBbEQWJmZgviIDEzswVxkJiZ2YKUGiSStks6JGlU0pUdnh+QdFt6fr+kTal9naS7JD0j6RNty7xO0g/SMtdLUpnvwczM8pUWJJKqwA3AhcBW4GJJW9tmuxR4KiLOA64Drk3tE8B/BP6gw6o/BVwGbElf2xe/ejMzm6syeyTbgNGIeDAipoA9wI62eXYAt6TpvcAFkhQRz0bEt8gCZZaks4DTIuLbkd0j+LPAu8p6A9d/40f8zQ/Hy1q9mdkpocwg2QA80vJ4LLV1nCci6sARYF3BOscK1rloPvnNUf736BNlrd7M7JRQZpB0OnYR85hnXvNLukzSiKSR8fH59SqEyDo+ZmbWTZlBMgac0/J4I/Bot3kk1YC1wJMF69xYsE4AIuLGiBiKiKH16+c0XMxzSOAcMTPLV2aQ3A1skbRZUj+wCxhum2cYuCRN7wT2RU4XICIeA45Ken06W+t3gC8vfukZkd89MjOzEgdtjIi6pCuAO4EqcHNEHJR0NTASEcPATcCtkkbJeiK7ZpaX9DBwGtAv6V3A2yLifuB9wGeAQeCv0lcpJLlHYmZWoNTRfyPiDuCOtrarWqYngHd3WXZTl/YR4BWLV2V3WY/
ESWJmlsdXtufxMRIzs0IOkhy+ZN7MrJiDJEd2jMRdEjOzPA6SHJLP2jIzK+IgySF8jMTMrIiDJIckn7VlZlbAQZKjImg6R8zMcjlIcvmCRDOzIg6SHNkts5wkZmZ5HCQ5fLDdzKyYgySHR/81MyvmIMkhfNaWmVkRB0kO90jMzIo5SHL4fiRmZsUcJDl8PxIzs2IOkgI+RmJmls9BkkPet2VmVshBksOj/5qZFXOQ5BC+H4mZWREHSQ73SMzMijlIcniIFDOzYg6SHNn9SMzMLI+DJEfWI3GUmJnlcZDk8TESM7NCDpIcvh2JmVkxB0kO37PdzKyYgySHz9oyMyvmIMnhYeTNzIo5SHL4xlZmZsUcJDncIzEzK+YgKeAcMTPLV2qQSNou6ZCkUUlXdnh+QNJt6fn9kja1PLc7tR+S9PaW9t+TdFDSfZI+L2lFifW7R2JmVqC0IJFUBW4ALgS2AhdL2to226XAUxFxHnAdcG1adiuwCzgf2A58UlJV0gbg/cBQRLwCqKb5ynkP+Mp2M7MiZfZItgGjEfFgREwBe4AdbfPsAG5J03uBCyQpte+JiMmIeAgYTesDqAGDkmrASuDRst5ApeJdW2ZmRcoMkg3AIy2Px1Jbx3kiog4cAdZ1WzYifgL8d+D/AY8BRyLia51eXNJlkkYkjYyPj8/rDfh+JGZmxcoMEnVoa/9U7jZPx3ZJLyHrrWwGzgZWSXpvpxePiBsjYigihtavX/88ym4pzmNtmZkVKjNIxoBzWh5v5Lm7oWbnSbuq1gJP5iz7a8BDETEeEdPAF4E3llI9vrLdzGwuygySu4EtkjZL6ic7KD7cNs8wcEma3gnsi2xf0jCwK53VtRnYAhwg26X1ekkr07GUC4AHSnsHvh+JmVmhWlkrjoi6pCuAO8nOrro5Ig5KuhoYiYhh4CbgVkmjZD2RXWnZg5JuB+4H6sDlEdEA9kvaC3wntX8XuLGs9+CztszMipUWJAARcQdwR1vbVS3TE8C7uyx7DXBNh/Y/Bv54cSvtTJ2O1JiZ2Ul8ZXsOHyMxMyvmIMnh+5GYmRVzkORwj8TMrJiDJIdH/zUzK+YgyeH7kZiZFXOQ5HGPxMyskIMkh/AQKWZmRRwkOeQkMTMr5CDJ4WMkZmbFHCQ5fNaWmVkxB0kODyNvZlbMQZLDN7YyMyvmIMnhHomZWTEHSQF3SMzM8jlIcsg3tjIzK+QgySFwl8TMrICDJIePkZiZFXOQ5BDQdI/EzCyXgySHJO/ZMjMr4CDJUfGV7WZmhRwkuXzWlplZEQdJjmysLUeJmVkeB0kOLXcBZmYvAA6SHB7918ysmIMkh+9HYmZWzEGSwz0SM7NiDpIcvrLdzKyYgySH70diZlbMQZLHPRIzs0IOkhzZ6L/LXYWZWW8rNUgkbZd0SNKopCs7PD8g6bb0/H5Jm1qe253aD0l6e0v76ZL2Svp7SQ9IekOJ9TtHzMwKlBYkkqrADcCFwFbgYklb22a7FHgqIs4DrgOuTctuBXYB5wPbgU+m9QF8HPhqRPwy8CrggdLeA76y3cysSJk9km3AaEQ8GBFTwB5gR9s8O4Bb0vRe4AJJSu17ImIyIh4CRoFtkk4D3gLcBBARUxFxuKw34LO2zMyKlRkkG4BHWh6PpbaO80REHTgCrMtZ9uXAOPDnkr4r6dOSVpVT/kyPpKy1m5mdGsoMkk5DVbV/LHebp1t7DXgt8KmIeA3wLPCcYy8Aki6TNCJpZHx8fO5Vn7wOX9luZlagzCAZA85pebwReLTbPJJqwFrgyZxlx4CxiNif2veSBctzRMSNETEUEUPr16+f1xtwj8TMrFiZQXI3sEXSZkn9ZAfPh9vmGQYuSdM7gX2RHd0eBnals7o2A1uAAxHxU+ARSf8wLXMBcH9p78BDpJiZFaqVteKIqEu6ArgTqAI3R8RBSVcDIxExTHbQ/FZJo2Q9kV1p2YOSbicLiTpweUQ00qr/HfC5FE4PAr9b1nuQB5I3Myt
UWpAARMQdwB1tbVe1TE8A7+6y7DXANR3a7wWGFrfSznxjKzOzYr6yPYfw6b9mZkUcJDk8jLyZWTEHSQ7f2MrMrJiDJIcETeeImVkuB0kO79oyMys2pyCRdOtc2k49PtxuZlZkrj2S81sfpJF4X7f45fSWinskZmaFcoMk3RPkKPBKSU+nr6PA48CXl6TCZeTRf83MiuUGSUT8l4hYA/xJRJyWvtZExLqI2L1ENS4b37PdzKzYXHdtfWVmuHZJ75X0UUnnllhXT3CPxMys2FyD5FPAMUmvAj4E/Bj4bGlV9QiP/mtmVmyuQVJPo/LuAD4eER8H1pRXVm+QvGvLzKzIXAdtPCppN/DbwD9OZ231lVdW73CMmJnlm2uP5CJgEviX6Z4gG4A/Ka2qHiFfRmJmVmhOQZLC43PAWknvBCYi4kVwjETOETOzAnO9sv23gANk9w75LWC/pJ1lFtYLfD8SM7Nicz1G8mHgVyPicQBJ64G/Jrtn+inLe7bMzIrN9RhJZSZEkp8/j2VfsDxoo5lZsbn2SL4q6U7g8+nxRbTdQvdUJPl+JGZmRXKDRNJ5wEsj4g8l/XPgzWR7fL5NdvD9lOYLEs3MihXtnvoYcBQgIr4YEb8fEb9H1hv5WNnFLTsPkWJmVqgoSDZFxPfbGyNiBNhUSkU9RE4SM7NCRUGyIue5wcUspBdlgzY6SczM8hQFyd2S/lV7o6RLgXvKKal3+BiJmVmxorO2Pgh8SdJ7OBEcQ0A/8JtlFtYLPIy8mVmx3CCJiJ8Bb5T0T4FXpOb/FRH7Sq+sB/jGVmZmxeZ0HUlE3AXcVXItPcc9EjOzYqf81ekL4WMkZmbFHCR5JMADN5qZ5XGQ5FD67hwxM+vOQZIjdUh8nMTMLEepQSJpu6RDkkYlXdnh+QFJt6Xn90va1PLc7tR+SNLb25arSvqupK+UWj/etWVmVqS0IEn3db8BuBDYClwsaWvbbJcCT0XEecB1wLVp2a3ALuB8YDvwybS+GR8AHiir9hnukZiZFSuzR7INGI2IByNiCtgD7GibZwdwS5reC1wgSal9T0RMRsRDwGhaH5I2Ar8OfLrE2gGozASJk8TMrKsyg2QD8EjL47HU1nGeiKgDR4B1Bct+DPgQ0Mx7cUmXSRqRNDI+Pj6vN6CZs7bcJzEz66rMIFGHtvZP5G7zdGyX9E7g8YgoHOcrIm6MiKGIGFq/fn1xtbnrWtDiZmantDKDZAw4p+XxRuDRbvNIqgFrgSdzln0T8BuSHibbVfZWSX9RRvFZTWWt2czs1FFmkNwNbJG0WVI/2cHz4bZ5hoFL0vROYF9kp0gNA7vSWV2bgS3AgYjYHREbI2JTWt++iHhvWW/gxFlbZb2CmdkL31zv2f68RURd0hXAnUAVuDkiDkq6GhiJiGHgJuBWSaNkPZFdadmDkm4H7gfqwOUR0Sir1m5OnLXlJDEz66a0IAGIiDvIbsvb2nZVy/QE8O4uy14DXJOz7m8C31yMOrvxle1mZsV8ZXsOX0diZlbMQZLDV7abmRVzkORwj8TMrJiDZA7cITEz685BkkPukpiZFXKQ5Jg9a8tJYmbWlYMkhzxoo5lZIQdJjhM9EjMz68ZBkkO+Z7uZWSEHSQ4fazczK+YgyeEhUszMijlI8vjGVmZmhRwkOWZvR+IcMTPrykGSw8dIzMyKOUhyzAza2PRBEjOzrhwkOXxBoplZMQdJDl+QaGZWzEGS40SPxFFiZtaNgyTHiRtbLXMhZmY9zEGSQyqex8zsxc5BkuPEWFvLXIiZWQ9zkOTw/UjMzIo5SHL49F8zs2IOkhy+st3MrJiDJMeJs7YcJWZm3ThIcrhHYmZWzEEyB+6QmJl15yDJIXmQFDOzIg6SHL5DoplZMQdJDh8jMTMrVmqQSNou6ZCkUUlXdnh+QNJt6fn9kja1PLc7tR+S9PbUdo6kuyQ9IOmgpA+UWr/H2jIzK1RakEiqAjcAFwJbgYslbW2b7VLgqYg4D7gOuDY
tuxXYBZwPbAc+mdZXB/59RPwK8Hrg8g7rXMT3kH33le1mZt2V2SPZBoxGxIMRMQXsAXa0zbMDuCVN7wUuUHaEewewJyImI+IhYBTYFhGPRcR3ACLiKPAAsKGsN+BjJGZmxcoMkg3AIy2Px3juh/7sPBFRB44A6+aybNoN9hpg/yLWfBIPkWJmVqzMIOk0CHv7R3K3eXKXlbQa+ALwwYh4uuOLS5dJGpE0Mj4+PseSn7OW9MJOEjOzbsoMkjHgnJbHG4FHu80jqQasBZ7MW1ZSH1mIfC4ivtjtxSPixogYioih9evXz+sNuEdiZlaszCC5G9giabOkfrKD58Nt8wwDl6TpncC+yAa2GgZ2pbO6NgNbgAPp+MlNwAMR8dESawc6d4vMzOxktbJWHBF1SVcAdwJV4OaIOCjpamAkIobJQuFWSaNkPZFdadmDkm4H7ic7U+vyiGhIejPw28APJN2bXuo/RMQdZbwH39jKzKxYaUECkD7g72hru6plegJ4d5dlrwGuaWv7FkvYUfCNrczMivnK9hwzx0iazhEzs64cJDlOHGx3kpiZdeMgyTE7RMoy12Fm1sscJHl8+q+ZWSEHSY4TR/WdJGZm3ThIcvj0XzOzYg6SHL4/oplZMQdJjop7JGZmhRwkOXz6r5lZMQdJDu/aMjMr5iDJ49N/zcwKOUhyyPcjMTMr5CDJIe/bMjMr5CDJ4RwxMyvmIMnhCxLNzIo5SHLMnv7rPomZWVcOkhyzu7acI2ZmXTlIcpzokZiZWTcOklwzx0gcJWZm3ThIcrhHYmZWzEGSY/Z+JE4SM7OuHCQ5Zk//dZKYmXXlIMnhs7bMzIo5SHLIgzaamRVykOQ4MWijmZl14yDJMdMjabpLYmbWlYNkDpwjZmbdOUhyyOf/mpkVcpDkmD1G4hwxM+vKQZJjpkcy1WgubyFmZj2sVubKJW0HPg5UgU9HxH9te34A+CzwOuDnwEUR8XB6bjdwKdAA3h8Rd85lnYtpw0sGecnKPq768kH23jPGulX9rOirsqKvykBfhRW1anpcOfG9duL5gVqVVQNVzljZz+kr++mvObfN7NRTWpBIqgI3AP8MGAPuljQcEfe3zHYp8FREnCdpF3AtcJGkrcAu4HzgbOCvJf1SWqZonYvmtBV93Pav38An9o3y0BPP8vDPn2ViusnkdIOJepOp+vPrqazqr3L6yn4G+ipEZINBBtBoBhPTTV6+fhUDtQrVStYVmphuMNhXZaBW5ZnJ+mxgTUw3GahVqFTE8ak6jWZw1umDiOxoTrYrLtJrQK0qVg/UOD7dYEVflXojCIJGMxjsrzJQrYBEo9mk3gxW1KpUJAb6KhybarCqv8p0o0kzoCKYagSnD/YxUW/QV6nM1lVvBvXmif2Aannvanmglmdm2mdaKhWxfvUAz0zW6a9V6K9VGKhWeGayTl+1Mrv+/lqFgRTMzWYw0FelryqOTtSpVcSaFX3Um02EWDvYx9qVfawZqFGptFZlZouhzB7JNmA0Ih4EkLQH2AG0fujvAD6SpvcCn1A2LskOYE9ETAIPSRpN62MO61xUv/TSNVx/8Ws6PtdsBpP1JhPTDSbqDSanm0zUG0xMp7bpBs9ONnjq2BRPPjvFkePTHD42zUS9QUWiouwDtCJRqYiHn3iWoxP12dONV9SqPPHMFBPTDVavqPHUsSbHphoM1CpMN5opCGoI+P7YEbJtAiCU1i3BVL3Js1NZKE1MN2aDqr+WBcVMINYqoloRU43mKXlcqCJYO9jHaYN9HJ2os23TGZx9+iDHpxv0V8WxqQYSnLtuFdWKqEocOT5NtaIU4A36qtnPqirNbsfDx6Y5c3U/Z58+yHQjaEQw2FdlZX+VqUaTnx2Z4Pyz1/KTw8f4+v2Ps/N1G1k72HdSbVL6avvZQfZPQLUiJqYbNJpBrSqmG8FTx6Z42RkrOT7VoFLJfp+qEkq/W9XKydMVZeuuSGm+bBig6UaT6UaTlf2dPw6azaAZQa2
6/D3qiens91/yPwS9pMwg2QA80vJ4DPhH3eaJiLqkI8C61P53bctuSNNF61wylYoY7K8y2F9drhIWxcww+TN/nNPpmNBMj2ii3qS/WkGCeiNm//Nf0Vdlqt5kstFgYqpJX03UKtmHzUnjk3WenA2r1nmn6k3Gj05y2mAfU/UmU40mk9NNBvurNJrN2V7JdL3JdCOoKPs5TKYe4poVNSbrTY5P1alVKjQjeHqizuFjUxw+Ns3h41McOZ71Wg489CR/+6NxVvbXmG40qVZEvdHk6Yn6bD3Vimg081NVen4nZHzhO2Nzn7lkM7VLsLq/xnQz+wdluhFUK2KgVqGewnGgVqGRQkUpjCbqDfqrFfqqFVLupRBsDUSdFIytQXnk+HS2W7hWpdG2Edu3aUTw82enOHN1PwO16myPvhlBM/W+K4JjUw0G+6vUKjrpdyzixO9fJQXqTC++3gyazZgN3ueGeva3Ualk7RWdeF/MI9PmG4PPN0DPWNnP7f/mDfN8tbkrM0g6veP2P7du83Rr7/QvUcc/YUmXAZcBvOxlL+tepT3nl7Mv/ec58311y3+ifSkzX7KqHyCF6Mn/XS/UuetWLer6no9mM5hqNGlG2vWX3vBkvclg2n0381wjsg+n01bU+OnTE/z8mSn6axUqguNTTY5N1WlEsHawj7GnjrNmoMZ5v7Caex85zMnZdOJDLvve+jg4OlEnItt9V6todvfeQK3C+NFJ1g72EQGNiJYP1Ugf+idPNyNSD2PmAzgLDJH1vmpVUatkX82AyXrW2+mrVGZ7s5WKZtcz88/EdNpdOvvBnT7kn/t+Wt5vwKqBGhP1Bo1GkP0PcvLvYvvn5pmrB3j08HGaES29es0u22wGKweqHJ/Kem8zgTCzrtZhj2a2FUBfVemfhlR7h7pnH88G2PzuVTTvzv48FlyzotTD4LPKfJUx4JyWxxuBR7vMMyapBqwFnixYtmidAETEjcCNAENDQ6fgjhorQ6UiVlSe28Oc2a3T3+UYy1lrBzlr7WDX9Z5/9trZ6bed/w8WWKVZbylzp+fdwBZJmyX1kx08H26bZxi4JE3vBPZFFvHDwC5JA5I2A1uAA3Ncp5mZLaHSeiTpmMcVwJ1kp+reHBEHJV0NjETEMHATcGs6mP4kWTCQ5rud7CB6Hbg8IhoAndZZ1nswM7NiejHcj3xoaChGRkaWuwwzsxcMSfdExNBc5l3+8/nMzOwFzUFiZmYL4iAxM7MFcZCYmdmCOEjMzGxBXhRnbUkaB348z8XPBJ5YxHIWk2ubv16uz7XNXy/X90Kr7dyIWD+XhV8UQbIQkkbmegrcUnNt89fL9bm2+evl+k7l2rxry8zMFsRBYmZmC+IgKXbjcheQw7XNXy/X59rmr5frO2Vr8zESMzNbEPdIzMxsQRwkXUjaLumQpFFJVy53PQCSHpb0A0n3ShpJbWdI+rqkH6XvL1miWm6W9Lik+1raOtaizPVpW35f0muXobaPSPpJ2nb3SnpHy3O7U22HJL295NrOkXSXpAckHZT0gdTeK9uuW33Lvv0krZB0QNL3Um3/KbVvlrQ/bbvb0i0mSLehuC3Vtl/SpmWo7TOSHmrZbq9O7Uv6c02vWZX0XUlfSY8Xb7tldwPzV+sX2RD1/xd4OdAPfA/Y2gN1PQyc2db234Ar0/SVwLVLVMtbgNcC9xXVArwD+Cuy29+9Hti/DLV9BPiDDvNuTT/fAWBz+rlXS6ztLOC1aXoN8MNUQ69su271Lfv2S9tgdZruA/anbXI7sCu1/ynwvjT9b4E/TdO7gNtK3G7davsMsLPD/Ev6c02v+fvAXwJfSY8Xbbu5R9LZNmA0Ih6MiClgD7BjmWvqZgdwS5q+BXjXUrxoRPwt2T1k5lLLDuCzkfk74HRJZy1xbd3sAPZExGREPASMkv38y6rtsYj4Tpo+CjwAbKB3tl23+rpZsu2XtsEz6WFf+grgrcDe1N6+7Wa26V7gAul53vR84bV1s6Q/V0k
bgV8HPp0ei0Xcbg6SzjYAj7Q8HiP/j2mpBPA1Sfcouyc9wEsj4jHIPgSAX1i26rrX0ivb84q0G+Hmll2Ay1Zb2mXwGrL/Xntu27XVBz2w/dLumXuBx4Gvk/WADkdEvcPrz9aWnj8CrFuq2iJiZrtdk7bbdZIG2mvrUHcZPgZ8CGimx+tYxO3mIOmsU/r2wultb4qI1wIXApdLestyFzRHvbA9PwX8IvBq4DHgf6T2ZalN0mrgC8AHI+LpvFk7tC1HfT2x/SKiERGvBjaS9Xx+Jef1l7U2Sa8AdgO/DPwqcAbwR0tdm6R3Ao9HxD2tzTmv/7xrc5B0Ngac0/J4I/DoMtUyKyIeTd8fB75E9of0s5kucfr++PJV2LWWZd+eEfGz9IfeBP6ME7tflrw2SX1kH9Kfi4gvpuae2Xad6uul7ZfqOQx8k+z4wumSZm4b3vr6s7Wl59cy912ei1Hb9rSrMCJiEvhzlme7vQn4DUkPk+2mfytZD2XRtpuDpLO7gS3prIZ+sgNOw8tZkKRVktbMTANvA+5LdV2SZrsE+PLyVAg5tQwDv5POVHk9cGRmN85Sadv//Jtk226mtl3pTJXNwBbgQIl1CLgJeCAiPtryVE9su2719cL2k7Re0ulpehD4NbJjOHcBO9Ns7dtuZpvuBPZFOoK8RLX9fcs/ByI7BtG63Zbk5xoRuyNiY0RsIvss2xcR72Ext1vZZwq8UL/Izqr4Idk+2A/3QD0vJzs75nvAwZmayPZdfgP4Ufp+xhLV83myXRzTZP/BXNqtFrKu8g1pW/4AGFqG2m5Nr/399IdyVsv8H061HQIuLLm2N5PtJvg+cG/6ekcPbbtu9S379gNeCXw31XAfcFXL38YBsgP9/xMYSO0r0uPR9PzLl6G2fWm73Qf8BSfO7FrSn2tLnf+EE2dtLdp285XtZma2IN61ZWZmC+IgMTOzBXGQmJnZgjhIzMxsQRwkZma2IA4Ss0UgqdEywuu9WsQRoyVtUstIxma9plY8i5nNwfHIhscwe9Fxj8SsRMruIXNtulfFAUnnpfZzJX0jDeb3DUkvS+0vlfQlZfe1+J6kN6ZVVSX9mbJ7XXwtXT1t1hMcJGaLY7Bt19ZFLc89HRHbgE+QjXFEmv5sRLwS+BxwfWq/HvibiHgV2T1VDqb2LcANEXE+cBj4FyW/H7M585XtZotA0jMRsbpD+8PAWyPiwTQY4k8jYp2kJ8iGGZlO7Y9FxJmSxoGNkQ3yN7OOTWTDkm9Jj/8I6IuI/1z+OzMr5h6JWfmiy3S3eTqZbJlu4OOb1kMcJGblu6jl+7fT9P8hG4kV4D3At9L0N4D3weyNkk5bqiLN5sv/1ZgtjsF0d7wZX42ImVOAByTtJ/vH7eLU9n7gZkl/CIwDv5vaPwDcKOlSsp7H+8hGMjbrWT5GYlaidIxkKCKeWO5azMriXVtmZrYg7pGYmdmCuEdiZmYL4iAxM7MFcZCYmdmCOEjMzGxBHCRmZrYgDhIzM1uQ/w/ptlV1wVOjOwAAAABJRU5ErkJggg==\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#Plot the loss\n", "\n", "plt.plot(rnn.loss_list)\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Cost\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "#save in a folder that describes the model\n", "folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + 
str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", "rnn.save(folder)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm_leaky_relu_5l_[50,40,30,20,10]c/rnn_basic\n" ] } ], "source": [ "#folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", "#rnn.load(folder)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "###test_input.shape###" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "#Here I predict based on my test set\n", "\n", "test_pred = rnn.predict(test_input)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0.00610282 0.00100984 0.02600916]\n", " [ 0.01632101 -0.01520294 0.02987524]\n", " [ 0.06068288 0.00697896 0.06441782]\n", " [-0.0119639 -0.04535145 0.07225598]\n", " [ 0.04132241 0.01145548 0.05150088]\n", " [-0.03290992 0.10355402 0.09310361]\n", " [ 0.00265487 0.04124176 0.08941123]]\n" ] } ], "source": [ "#Here I subtract the target from the prediction (prediction - target) and print the result for the first test example (index 0) to get an idea of the prediction error\n", "\n", "#scaler_inv(test_input, scalerfunc = func)[0,:,:]\n", "\n", "\n", "diff = scaler_inv(test_pred, scalerfunc = func)-scaler_inv(test_target, scalerfunc = func )\n", "\n", "print(diff[0,:,:])" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7.513113e-06" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Here I evaluate my model on the test set based on mean_squared_error\n", "\n", "rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }