{ "cells": [
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", " from ._conv import register_converters as _register_converters\n" ] } ], "source": [
 "import math\n",
 "import os\n",
 "import random\n",
 "\n",
 "import matplotlib as mpl\n",
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import pandas as pd\n",
 "import tensorflow as tf\n",
 "from sklearn import preprocessing\n",
 "from tensorflow.python.framework import ops"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
 "# Load the data: one row per particle track with 8 hits, each given as x, y, z.\n",
 "# NOTE: pd.read_pickle can execute arbitrary code - only load trusted files.\n",
 "\n",
 "testset = pd.read_pickle('matched_8hittracks.pkl')"
] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
 "# Convert to a float32 array of shape (num_particles, 3 * num_hits)\n",
 "\n",
 "tset = np.array(testset)\n",
 "tset = tset.astype('float32')"
] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
 "### Reshape between the flat layout (particles, 3*timesteps) and the RNN layout (particles, timesteps, 3) ###\n",
 "\n",
 "def reshapor(arr_orig):\n",
 "    \"\"\"Reshape a flat (num_examples, 3*timesteps) array into (num_examples, timesteps, 3).\n",
 "\n",
 "    Column 3*t + k of each row becomes element [t, k]. Trailing columns that do not\n",
 "    fill a complete group of three are dropped.\n",
 "    Returns a new float64 array; the input is left untouched.\n",
 "    \"\"\"\n",
 "    timesteps = arr_orig.shape[1] // 3\n",
 "    number_examples = arr_orig.shape[0]\n",
 "    flat = np.array(arr_orig[:, :3 * timesteps], dtype=np.float64)\n",
 "    return flat.reshape(number_examples, timesteps, 3)\n",
 "\n",
 "\n",
 "def reshapor_inv(array_shaped):\n",
 "    \"\"\"Inverse of `reshapor`: (num_examples, timesteps, 3) -> (num_examples, 3*timesteps).\n",
 "\n",
 "    Returns a new float64 array; the input is left untouched.\n",
 "    \"\"\"\n",
 "    timesteps = array_shaped.shape[1]\n",
 "    num_examples = array_shaped.shape[0]\n",
 "    stacked = np.array(array_shaped, dtype=np.float64)\n",
 "    return stacked.reshape(num_examples, timesteps * 3)"
] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
 "### Create the training set and the test set ###\n",
 "\n",
 "def create_random_sets(dataset, train_to_total_ratio):\n",
 "    \"\"\"Randomly split the rows of `dataset` into a training set and a test set.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    dataset : np.ndarray of shape (num_examples, num_features)\n",
 "    train_to_total_ratio : float in [0, 1]\n",
 "        Fraction of the rows (rounded down) that go into the training set.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    train_set, test_set : np.ndarray (float64)\n",
 "        Disjoint subsets of the shuffled rows; together they contain every row once.\n",
 "\n",
 "    Raises\n",
 "    ------\n",
 "    ValueError\n",
 "        If `train_to_total_ratio` is outside [0, 1].\n",
 "    \"\"\"\n",
 "    if not 0 <= train_to_total_ratio <= 1:\n",
 "        raise ValueError(\"train_to_total_ratio must be in [0, 1], got {}\".format(train_to_total_ratio))\n",
 "\n",
 "    # shuffle the dataset\n",
 "    num_examples = dataset.shape[0]\n",
 "    shuffled = dataset[np.random.permutation(num_examples), :]\n",
 "\n",
 "    # size of the training set (builtin int: np.int no longer exists in NumPy >= 1.24)\n",
 "    train_set_size = int(num_examples * train_to_total_ratio)\n",
 "\n",
 "    # copy both parts into new float64 arrays\n",
 "    train_set = np.array(shuffled[:train_set_size], dtype=np.float64)\n",
 "    test_set = np.array(shuffled[train_set_size:], dtype=np.float64)\n",
 "\n",
 "    return train_set, test_set"
] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
 "# Keep 99% of the tracks for training and 1% for testing\n",
 "train_set, test_set = create_random_sets(tset, 0.99)"
] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
 "# Normalize the data with scikit-learn's MinMaxScaler.\n",
 "# The scaler works on 2-D data: 3-D arrays (particles, timesteps, 3) are flattened with\n",
 "# `reshapor_inv` before and restored with `reshapor` after, so every hit coordinate\n",
 "# (column of the flat layout) gets its own min/max.\n",
 "\n",
 "def _apply_flat(func, arr):\n",
 "    \"\"\"Apply a 2-D -> 2-D function to `arr`, flattening and restoring 3-D input.\"\"\"\n",
 "    if len(arr.shape) == 3:\n",
 "        return reshapor(func(reshapor_inv(arr)))\n",
 "    return func(arr)\n",
 "\n",
 "\n",
 "# set the transformation based on the training set\n",
 "def set_min_max_scalor(arr, feature_range=(-1, 1)):\n",
 "    \"\"\"Return a MinMaxScaler fitted on `arr` (2-D, or 3-D via flattening).\"\"\"\n",
 "    min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n",
 "    # only the fitted scaler is needed, so fit instead of fit_transform\n",
 "    min_max_scalor.fit(reshapor_inv(arr) if len(arr.shape) == 3 else arr)\n",
 "    return min_max_scalor\n",
 "\n",
 "min_max_scalor = set_min_max_scalor(train_set)\n",
 "\n",
 "\n",
 "# transform data\n",
 "# NOTE: the default scaler is bound when this cell runs; re-run the cell after\n",
 "# refitting `min_max_scalor`, or pass the scaler explicitly.\n",
 "def min_max_scaler(arr, min_max_scalor=min_max_scalor):\n",
 "    \"\"\"Scale `arr` (2-D or 3-D) with a fitted MinMaxScaler.\"\"\"\n",
 "    return _apply_flat(min_max_scalor.transform, arr)\n",
 "\n",
 "\n",
 "# inverse transformation\n",
 "def min_max_scaler_inv(arr, min_max_scalor=min_max_scalor):\n",
 "    \"\"\"Undo `min_max_scaler` (2-D or 3-D input).\"\"\"\n",
 "    return _apply_flat(min_max_scalor.inverse_transform, arr)"
] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
 "# Bring both sets into the RNN layout (particles, timesteps, 3).\n",
 "# NOTE: run once per kernel session - this cell rebinds train_set/test_set.\n",
 "train_set = reshapor(train_set)\n",
 "test_set = reshapor(test_set)"
] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
 "# Scale both sets with the scaler fitted on the training set only.\n",
 "# NOTE: run once per kernel session - re-running would scale already scaled data.\n",
 "train_set = min_max_scaler(train_set)\n",
 "test_set = min_max_scaler(test_set)"
] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
 "#train_set = min_max_scaler_inv(train_set)\n",
 "\n",
 "#print(train_set[0,:,:])"
] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
 "### Create random mini-batches ###\n",
 "\n",
 "def unison_shuffled_copies(a, b):\n",
 "    \"\"\"Shuffle two 3-D arrays along axis 0 with the same random permutation.\"\"\"\n",
 "    assert a.shape[0] == b.shape[0]\n",
 "    p = np.random.permutation(a.shape[0])\n",
 "    return a[p,:,:], b[p,:,:]\n",
 "\n",
 "\n",
 "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
 "    \"\"\"Shuffle (inputt, target) jointly and split them into minibatches along axis 0.\n",
 "\n",
 "    Returns a list of (input_batch, target_batch) tuples. Every batch holds\n",
 "    `minibatch_size` examples except possibly the last one, which holds the rest.\n",
 "    No empty batch is produced when the number of examples is a multiple of\n",
 "    `minibatch_size` (an empty batch would still be fed to the optimizer and\n",
 "    counted in the per-epoch average loss of `fit`).\n",
 "    \"\"\"\n",
 "    num_examples = inputt.shape[0]\n",
 "\n",
 "    # shuffle particles\n",
 "    _i, _t = unison_shuffled_copies(inputt, target)\n",
 "\n",
 "    return [(_i[start:start + minibatch_size, :, :], _t[start:start + minibatch_size, :, :])\n",
 "            for start in range(0, num_examples, minibatch_size)]"
] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [
 "# Create random minibatches of the training set. The input is every hit but the last,\n",
 "# the target is the same track shifted by one hit (every hit but the first).\n",
 "# NOTE: the batches are built once here and reused unchanged in every epoch of `fit`.\n",
 "\n",
 "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:], minibatch_size = 1000)\n",
 "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]"
] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
 "class RNNPlacePrediction():\n",
 "    \"\"\"Multi-layer RNN that predicts the next position of a track at every time step.\n",
 "\n",
 "    The input self.X has shape (batch, time_steps, ninputs); the target self.Y and the\n",
 "    network output have shape (batch, time_steps, num_output).\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    time_steps : int\n",
 "        Length of the input sequences.\n",
 "    future_steps : int\n",
 "        Stored as `self.future_steps`; not used when building the graph.\n",
 "    ninputs : int\n",
 "        Number of features per time step (here the coordinates).\n",
 "    ncells : int or list/tuple of int\n",
 "        Number of units of each stacked RNN layer (one entry per layer).\n",
 "    num_output : int\n",
 "        Number of values predicted per time step.\n",
 "    cell_type : str\n",
 "        \"basic_rnn\", \"lstm\" or \"GRU\".\n",
 "    activation : str\n",
 "        \"relu\", \"tanh\", \"leaky_relu\" or \"elu\"; used by every layer except the last\n",
 "        one, which always uses tanh.\n",
 "\n",
 "    Call `set_cost_and_functions` once before `fit` or `predict`.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\"):\n",
 "        self.nsteps = time_steps\n",
 "        self.future_steps = future_steps\n",
 "        self.ninputs = ninputs\n",
 "        self.ncells = ncells\n",
 "        self.num_output = num_output\n",
 "        self._ = cell_type  # later used to create the checkpoint folder name\n",
 "        self.__ = activation  # later used to create the checkpoint folder name\n",
 "\n",
 "        #### The input is of shape (num_examples, time_steps, ninputs)\n",
 "        #### ninputs is the dimensionality (number of features) of the time series (here coordinates)\n",
 "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
 "        #### The target is compared with the network output, which has num_output values per step\n",
 "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, num_output))\n",
 "\n",
 "        # Check if the activation function is valid and set it\n",
 "        activations = {\n",
 "            \"relu\": tf.nn.relu,\n",
 "            \"tanh\": tf.nn.tanh,\n",
 "            \"leaky_relu\": tf.nn.leaky_relu,\n",
 "            \"elu\": tf.nn.elu,\n",
 "        }\n",
 "        if activation not in activations:\n",
 "            raise ValueError(\"Wrong rnn avtivation function: {}\".format(activation))\n",
 "        self.activation = activations[activation]\n",
 "\n",
 "        # Check if the cell type is valid and set it\n",
 "        cell_types = {\n",
 "            \"basic_rnn\": tf.contrib.rnn.BasicRNNCell,\n",
 "            \"lstm\": tf.contrib.rnn.BasicLSTMCell,\n",
 "            \"GRU\": tf.contrib.rnn.GRUCell,\n",
 "        }\n",
 "        if cell_type not in cell_types:\n",
 "            raise ValueError(\"Wrong rnn cell type: {}\".format(cell_type))\n",
 "        self.cell_type = cell_types[cell_type]\n",
 "\n",
 "        # Check the input of ncells: a single int or a list/tuple of ints\n",
 "        if type(self.ncells) == int:\n",
 "            self.ncells = [self.ncells]\n",
 "        elif type(self.ncells) == tuple:\n",
 "            self.ncells = list(self.ncells)\n",
 "\n",
 "        if type(self.ncells) != list or any(type(units) != int for units in self.ncells):\n",
 "            raise ValueError(\"Wrong type of Input for ncells\")\n",
 "        if not self.ncells:\n",
 "            raise ValueError(\"ncells must contain at least one layer size\")\n",
 "\n",
 "        # Every layer uses the chosen activation except the last one, which uses tanh\n",
 "        self.activationlist = [self.activation] * (len(self.ncells) - 1) + [tf.nn.tanh]\n",
 "\n",
 "        self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=units, activation=act)\n",
 "                                                 for units, act in zip(self.ncells, self.activationlist)])\n",
 "\n",
 "        #### Project the output of the last layer to num_output values per time step\n",
 "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size=num_output)\n",
 "\n",
 "        self.sess = tf.Session()\n",
 "\n",
 "    def set_cost_and_functions(self, LR=0.001):\n",
 "        \"\"\"Build the unrolled RNN, the MSE cost and the Adam training op, then initialize all variables.\n",
 "\n",
 "        LR is the Adam learning rate.\n",
 "        \"\"\"\n",
 "        #### Unroll the RNN cell over the time steps\n",
 "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
 "        #### The cost is the mean squared error between prediction and target\n",
 "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))\n",
 "\n",
 "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
 "        #### Variable initializer\n",
 "        self.init = tf.global_variables_initializer()\n",
 "        self.saver = tf.train.Saver()\n",
 "        self.sess.run(self.init)\n",
 "\n",
 "    def save(self, filename=\"./rnn_model/rnn_basic\"):\n",
 "        \"\"\"Save all variables under the checkpoint prefix `filename`, creating its folder if needed.\"\"\"\n",
 "        folder = os.path.dirname(filename)\n",
 "        if folder:\n",
 "            os.makedirs(folder, exist_ok=True)\n",
 "        self.saver.save(self.sess, filename)\n",
 "\n",
 "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
 "        \"\"\"Restore all variables from the checkpoint prefix `filename`.\"\"\"\n",
 "        self.saver.restore(self.sess, filename)\n",
 "\n",
 "    def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n",
 "        \"\"\"Train on a list of (input, target) minibatches.\n",
 "\n",
 "        Parameters\n",
 "        ----------\n",
 "        minibatches : list of (np.ndarray, np.ndarray)\n",
 "            Batches fed to (self.X, self.Y); the same list is iterated in every epoch.\n",
 "        epochs : int\n",
 "            Maximum number of epochs.\n",
 "        print_step : int\n",
 "            Progress is printed every `print_step` epochs.\n",
 "        checkpoint : int\n",
 "            Every `checkpoint` epochs (from epoch 2 on) the model is saved if its epoch\n",
 "            loss is lower than the loss at the last checkpoint.\n",
 "        patience : int\n",
 "            Training stops once the epoch loss has changed by less than 2e-6 between\n",
 "            consecutive epochs `patience` times in total (not necessarily in a row).\n",
 "\n",
 "        The per-epoch losses are stored in `self.loss_list`. After training, the last\n",
 "        checkpoint is restored if it has a lower loss than the final epoch.\n",
 "\n",
 "        Raises\n",
 "        ------\n",
 "        ValueError\n",
 "            If `minibatches` is empty.\n",
 "        \"\"\"\n",
 "        if len(minibatches) == 0:\n",
 "            raise ValueError(\"minibatches must contain at least one batch\")\n",
 "\n",
 "        self.loss_list = []\n",
 "        patience_cnt = 0\n",
 "        epoche_save = 0\n",
 "        # no checkpoint file exists until the first save, so nothing can be restored before that\n",
 "        checkpoint_saved = False\n",
 "\n",
 "        folder = \"./rnn_model_{}_{}_{}c_checkpoint/rnn_basic\".format(self._, self.__, str(self.ncells).replace(\" \", \"\"))\n",
 "\n",
 "        for iep in range(epochs):\n",
 "            loss = 0\n",
 "\n",
 "            for train, target in minibatches:\n",
 "                #### Train the RNN cell; the target Y is the input X shifted by one time step\n",
 "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
 "                #### Batch loss, evaluated after the update\n",
 "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
 "\n",
 "            # Average the loss over the number of batches\n",
 "            loss /= len(minibatches)\n",
 "            self.loss_list.append(loss)\n",
 "\n",
 "            # Create a checkpoint if the performance is better than at the last checkpoint\n",
 "            if iep > 1 and iep%checkpoint == 0 and self.loss_list[iep] < self.loss_list[epoche_save]:\n",
 "                self.save(folder)\n",
 "                epoche_save = iep\n",
 "                checkpoint_saved = True\n",
 "\n",
 "            # Early stopping with (cumulative) patience\n",
 "            if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 2/1000000:\n",
 "                patience_cnt += 1\n",
 "\n",
 "            if patience_cnt + 1 > patience:\n",
 "                print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n",
 "                print(\"Cost: \",loss)\n",
 "                break\n",
 "\n",
 "            # Note that the loss here is multiplied with 1000 for easier reading\n",
 "            if iep%print_step==0:\n",
 "                print(\"Epoch number \",iep)\n",
 "                print(\"Cost: \",loss*1000, \"e-3\")\n",
 "                print(\"Patience: \",patience_cnt, \"/\", patience)\n",
 "                print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n",
 "\n",
 "        # Set the model back to the last checkpoint if its performance was better\n",
 "        # (only possible if a checkpoint was actually written during this run)\n",
 "        if checkpoint_saved and self.loss_list[epoche_save] < self.loss_list[-1]:\n",
 "            self.load(folder)\n",
 "            print(\"\\n\", \"Last checkpoint at epoch \", epoche_save, \" loaded\")\n",
 "            print(\"Performance at last checkpoint is \" ,self.loss_list[-1] - self.loss_list[epoche_save], \" better\" )\n",
 "\n",
 "    def predict(self, x):\n",
 "        \"\"\"Return the network output for `x` of shape (batch, time_steps, ninputs).\"\"\"\n",
 "        return self.sess.run(self.output, feed_dict={self.X:x})"
] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [
 "timesteps = 7  # input sequence length: all hits but the last one (train_set[:, :-1, :])\n",
 "future_steps = 1\n",
 "\n",
 "ninputs = 3  # x, y, z per hit\n",
 "\n",
 "# ncells as int or list/tuple of int (one entry per layer)\n",
 "ncells = [50, 40, 30, 20, 10]\n",
 "\n",
 "num_output = 3"
] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: 
retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use the retry module or similar alternatives.\n" ] } ], "source": [ "tf.reset_default_graph()\n", "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n", " ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "rnn.set_cost_and_functions()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch number 0\n", "Cost: 3770.231458734959 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 0 \n", "\n", "Epoch number 5\n", "Cost: 1649.7736788810569 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 5 \n", "\n", "Epoch number 10\n", "Cost: 625.2868418046768 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 10 \n", "\n", "Epoch number 15\n", "Cost: 294.9610768639027 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 15 \n", "\n", "Epoch number 20\n", "Cost: 209.0108957379422 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 20 \n", "\n", "Epoch number 25\n", "Cost: 174.1866168982171 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 25 \n", "\n", "Epoch number 30\n", "Cost: 149.8719225538538 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 30 \n", "\n", "Epoch number 35\n", "Cost: 131.33942407179387 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 35 \n", "\n", "Epoch number 40\n", "Cost: 115.83642023516462 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 40 \n", "\n", "Epoch number 45\n", "Cost: 107.55172256935151 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 45 \n", "\n", "Epoch number 50\n", "Cost: 98.54952309359895 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 50 \n", 
"\n", "Epoch number 55\n", "Cost: 95.66065657170529 e4\n", "Patience: 0 / 200\n", "Last checkpoint at: Epoch 55 \n", "\n", "Epoch number 60\n", "Cost: 90.34742145462239 e4\n", "Patience: 1 / 200\n", "Last checkpoint at: Epoch 60 \n", "\n", "Epoch number 65\n", "Cost: 84.77292855844853 e4\n", "Patience: 2 / 200\n", "Last checkpoint at: Epoch 65 \n", "\n", "Epoch number 70\n", "Cost: 78.54001398416275 e4\n", "Patience: 3 / 200\n", "Last checkpoint at: Epoch 70 \n", "\n", "Epoch number 75\n", "Cost: 75.23123551397882 e4\n", "Patience: 3 / 200\n", "Last checkpoint at: Epoch 75 \n", "\n", "Epoch number 80\n", "Cost: 73.33986362085697 e4\n", "Patience: 4 / 200\n", "Last checkpoint at: Epoch 80 \n", "\n", "Epoch number 85\n", "Cost: 69.12997319422504 e4\n", "Patience: 5 / 200\n", "Last checkpoint at: Epoch 85 \n", "\n", "Epoch number 90\n", "Cost: 65.79162087291479 e4\n", "Patience: 5 / 200\n", "Last checkpoint at: Epoch 90 \n", "\n", "Epoch number 95\n", "Cost: 61.82488113483216 e4\n", "Patience: 6 / 200\n", "Last checkpoint at: Epoch 95 \n", "\n", "Epoch number 100\n", "Cost: 59.33671109774646 e4\n", "Patience: 8 / 200\n", "Last checkpoint at: Epoch 100 \n", "\n", "Epoch number 105\n", "Cost: 57.19678456637453 e4\n", "Patience: 9 / 200\n", "Last checkpoint at: Epoch 105 \n", "\n", "Epoch number 110\n", "Cost: 55.66507161773266 e4\n", "Patience: 10 / 200\n", "Last checkpoint at: Epoch 110 \n", "\n", "Epoch number 115\n", "Cost: 54.365597526602286 e4\n", "Patience: 13 / 200\n", "Last checkpoint at: Epoch 115 \n", "\n", "Epoch number 120\n", "Cost: 52.487826807067755 e4\n", "Patience: 14 / 200\n", "Last checkpoint at: Epoch 120 \n", "\n", "Epoch number 125\n", "Cost: 51.60155072015651 e4\n", "Patience: 17 / 200\n", "Last checkpoint at: Epoch 125 \n", "\n", "Epoch number 130\n", "Cost: 51.004822227232 e4\n", "Patience: 20 / 200\n", "Last checkpoint at: Epoch 130 \n", "\n", "Epoch number 135\n", "Cost: 49.656663347590474 e4\n", "Patience: 22 / 200\n", "Last checkpoint at: 
Epoch 135 \n", "\n", "Epoch number 140\n", "Cost: 49.04315717756114 e4\n", "Patience: 26 / 200\n", "Last checkpoint at: Epoch 140 \n", "\n", "Epoch number 145\n", "Cost: 48.333713487583275 e4\n", "Patience: 29 / 200\n", "Last checkpoint at: Epoch 145 \n", "\n", "Epoch number 150\n", "Cost: 47.4689517447606 e4\n", "Patience: 33 / 200\n", "Last checkpoint at: Epoch 150 \n", "\n", "Epoch number 155\n", "Cost: 46.82262457827938 e4\n", "Patience: 38 / 200\n", "Last checkpoint at: Epoch 155 \n", "\n", "Epoch number 160\n", "Cost: 46.189470573308625 e4\n", "Patience: 43 / 200\n", "Last checkpoint at: Epoch 160 \n", "\n", "Epoch number 165\n", "Cost: 45.566867759570165 e4\n", "Patience: 48 / 200\n", "Last checkpoint at: Epoch 165 \n", "\n", "Epoch number 170\n", "Cost: 45.00874754120695 e4\n", "Patience: 53 / 200\n", "Last checkpoint at: Epoch 170 \n", "\n", "Epoch number 175\n", "Cost: 44.46649339367101 e4\n", "Patience: 58 / 200\n", "Last checkpoint at: Epoch 175 \n", "\n", "Epoch number 180\n", "Cost: 43.92929008587244 e4\n", "Patience: 63 / 200\n", "Last checkpoint at: Epoch 180 \n", "\n", "Epoch number 185\n", "Cost: 43.44754183585656 e4\n", "Patience: 68 / 200\n", "Last checkpoint at: Epoch 185 \n", "\n", "Epoch number 190\n", "Cost: 42.95319576371223 e4\n", "Patience: 73 / 200\n", "Last checkpoint at: Epoch 190 \n", "\n", "Epoch number 195\n", "Cost: 42.52819289909082 e4\n", "Patience: 78 / 200\n", "Last checkpoint at: Epoch 195 \n", "\n", "Epoch number 200\n", "Cost: 41.93341770665126 e4\n", "Patience: 83 / 200\n", "Last checkpoint at: Epoch 200 \n", "\n", "Epoch number 205\n", "Cost: 41.554861285902085 e4\n", "Patience: 88 / 200\n", "Last checkpoint at: Epoch 205 \n", "\n", "Epoch number 210\n", "Cost: 41.090038733834284 e4\n", "Patience: 93 / 200\n", "Last checkpoint at: Epoch 210 \n", "\n", "Epoch number 215\n", "Cost: 40.845294889221165 e4\n", "Patience: 98 / 200\n", "Last checkpoint at: Epoch 215 \n", "\n", "Epoch number 220\n", "Cost: 40.25109122170412 e4\n", 
"Patience: 103 / 200\n", "Last checkpoint at: Epoch 220 \n", "\n", "Epoch number 225\n", "Cost: 39.58158948002977 e4\n", "Patience: 108 / 200\n", "Last checkpoint at: Epoch 225 \n", "\n", "Epoch number 230\n", "Cost: 38.97598008327979 e4\n", "Patience: 113 / 200\n", "Last checkpoint at: Epoch 230 \n", "\n", "Epoch number 235\n", "Cost: 38.51150915502234 e4\n", "Patience: 118 / 200\n", "Last checkpoint at: Epoch 235 \n", "\n", "Epoch number 240\n", "Cost: 38.299499218292695 e4\n", "Patience: 123 / 200\n", "Last checkpoint at: Epoch 240 \n", "\n", "Epoch number 245\n", "Cost: 37.74655878821269 e4\n", "Patience: 128 / 200\n", "Last checkpoint at: Epoch 245 \n", "\n", "Epoch number 250\n", "Cost: 37.40582783567778 e4\n", "Patience: 133 / 200\n", "Last checkpoint at: Epoch 250 \n", "\n", "Epoch number 255\n", "Cost: 37.24810196720856 e4\n", "Patience: 138 / 200\n", "Last checkpoint at: Epoch 255 \n", "\n", "Epoch number 260\n", "Cost: 37.280498320197175 e4\n", "Patience: 143 / 200\n", "Last checkpoint at: Epoch 255 \n", "\n", "Epoch number 265\n", "Cost: 36.25094043487247 e4\n", "Patience: 147 / 200\n", "Last checkpoint at: Epoch 265 \n", "\n", "Epoch number 270\n", "Cost: 36.03106825315255 e4\n", "Patience: 152 / 200\n", "Last checkpoint at: Epoch 270 \n", "\n", "Epoch number 275\n", "Cost: 35.67509779191398 e4\n", "Patience: 156 / 200\n", "Last checkpoint at: Epoch 275 \n", "\n", "Epoch number 280\n", "Cost: 35.42137842506487 e4\n", "Patience: 161 / 200\n", "Last checkpoint at: Epoch 280 \n", "\n", "Epoch number 285\n", "Cost: 35.79035718390282 e4\n", "Patience: 164 / 200\n", "Last checkpoint at: Epoch 280 \n", "\n", "Epoch number 290\n", "Cost: 33.758991754594 e4\n", "Patience: 165 / 200\n", "Last checkpoint at: Epoch 290 \n", "\n", "Epoch number 295\n", "Cost: 34.39420328891658 e4\n", "Patience: 166 / 200\n", "Last checkpoint at: Epoch 290 \n", "\n", "Epoch number 300\n", "Cost: 33.66679522862777 e4\n", "Patience: 166 / 200\n", "Last checkpoint at: Epoch 300 \n", 
"\n", "Epoch number 305\n", "Cost: 34.23552023880976 e4\n", "Patience: 167 / 200\n", "Last checkpoint at: Epoch 300 \n", "\n", "Epoch number 310\n", "Cost: 33.27848409560132 e4\n", "Patience: 168 / 200\n", "Last checkpoint at: Epoch 310 \n", "\n", "Epoch number 315\n", "Cost: 32.72916789741275 e4\n", "Patience: 171 / 200\n", "Last checkpoint at: Epoch 315 \n", "\n", "Epoch number 320\n", "Cost: 32.42362023113255 e4\n", "Patience: 173 / 200\n", "Last checkpoint at: Epoch 320 \n", "\n", "Epoch number 325\n", "Cost: 33.13556412591579 e4\n", "Patience: 173 / 200\n", "Last checkpoint at: Epoch 320 \n", "\n", "Epoch number 330\n", "Cost: 34.35548811041294 e4\n", "Patience: 173 / 200\n", "Last checkpoint at: Epoch 320 \n", "\n", "Epoch number 335\n", "Cost: 31.17884152588692 e4\n", "Patience: 174 / 200\n", "Last checkpoint at: Epoch 335 \n", "\n", "Epoch number 340\n", "Cost: 33.64366251341206 e4\n", "Patience: 174 / 200\n", "Last checkpoint at: Epoch 335 \n", "\n", "Epoch number 345\n", "Cost: 32.388941939682404 e4\n", "Patience: 175 / 200\n", "Last checkpoint at: Epoch 335 \n", "\n", "Epoch number 350\n", "Cost: 29.8897856648298 e4\n", "Patience: 175 / 200\n", "Last checkpoint at: Epoch 350 \n", "\n", "Epoch number 355\n", "Cost: 30.779531522792706 e4\n", "Patience: 176 / 200\n", "Last checkpoint at: Epoch 350 \n", "\n", "Epoch number 360\n", "Cost: 32.77950439641767 e4\n", "Patience: 177 / 200\n", "Last checkpoint at: Epoch 350 \n", "\n", "Epoch number 365\n", "Cost: 34.279519781232516 e4\n", "Patience: 177 / 200\n", "Last checkpoint at: Epoch 350 \n", "\n", "Epoch number 370\n", "Cost: 29.02430596147129 e4\n", "Patience: 177 / 200\n", "Last checkpoint at: Epoch 370 \n", "\n", "Epoch number 375\n", "Cost: 31.375054398828997 e4\n", "Patience: 178 / 200\n", "Last checkpoint at: Epoch 370 \n", "\n", "Epoch number 380\n", "Cost: 33.813590144223355 e4\n", "Patience: 178 / 200\n", "Last checkpoint at: Epoch 370 \n", "\n", "Epoch number 385\n", "Cost: 28.6719871268786 e4\n", 
"Patience: 178 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 390\n", "Cost: 31.848519872081408 e4\n", "Patience: 179 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 395\n", "Cost: 29.007866582337847 e4\n", "Patience: 181 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 400\n", "Cost: 33.16965553552863 e4\n", "Patience: 181 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 405\n", "Cost: 32.650657305295795 e4\n", "Patience: 181 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 410\n", "Cost: 28.816359365319318 e4\n", "Patience: 181 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch number 415\n", "Cost: 29.141941761716886 e4\n", "Patience: 181 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 420\n", "Cost: 30.577135856877614 e4\n", "Patience: 182 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 425\n", "Cost: 29.400000695456217 e4\n", "Patience: 183 / 200\n", "Last checkpoint at: Epoch 385 \n", "\n", "Epoch number 430\n", "Cost: 26.99479599423865 e4\n", "Patience: 183 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 435\n", "Cost: 30.304402994744958 e4\n", "Patience: 184 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 440\n", "Cost: 29.647010675770172 e4\n", "Patience: 184 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 445\n", "Cost: 27.00613232012442 e4\n", "Patience: 185 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 450\n", "Cost: 27.036350567210864 e4\n", "Patience: 186 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 455\n", "Cost: 27.08697458729148 e4\n", "Patience: 187 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 460\n", "Cost: 28.004820329791055 e4\n", "Patience: 188 / 200\n", "Last checkpoint at: Epoch 430 \n", "\n", "Epoch number 465\n", "Cost: 26.3666685551722 
e4\n", "Patience: 188 / 200\n", "Last checkpoint at: Epoch 465 \n", "\n", "Epoch number 470\n", "Cost: 26.36444576560183 e4\n", "Patience: 188 / 200\n", "Last checkpoint at: Epoch 470 \n", "\n", "Epoch number 475\n", "Cost: 31.123574119695324 e4\n", "Patience: 188 / 200\n", "Last checkpoint at: Epoch 470 \n", "\n", "Epoch number 480\n", "Cost: 27.53822227068087 e4\n", "Patience: 189 / 200\n", "Last checkpoint at: Epoch 470 \n", "\n", "Epoch number 485\n", "Cost: 26.472763485334657 e4\n", "Patience: 189 / 200\n", "Last checkpoint at: Epoch 470 \n", "\n", "Epoch number 490\n", "Cost: 25.98736776990142 e4\n", "Patience: 190 / 200\n", "Last checkpoint at: Epoch 490 \n", "\n", "Epoch number 495\n", "Cost: 25.32091308781441 e4\n", "Patience: 191 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 500\n", "Cost: 26.51548171614079 e4\n", "Patience: 191 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 505\n", "Cost: 25.78474184934129 e4\n", "Patience: 191 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 510\n", "Cost: 26.016250708477294 e4\n", "Patience: 191 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 515\n", "Cost: 28.13248825754891 e4\n", "Patience: 191 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 520\n", "Cost: 28.441735156910852 e4\n", "Patience: 191 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 525\n", "Cost: 25.8854781079324 e4\n", "Patience: 193 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 530\n", "Cost: 25.448204473929202 e4\n", "Patience: 193 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 535\n", "Cost: 26.26546668483222 e4\n", "Patience: 193 / 200\n", "Last checkpoint at: Epoch 495 \n", "\n", "Epoch number 540\n", "Cost: 24.608338271525312 e4\n", "Patience: 196 / 200\n", "Last checkpoint at: Epoch 540 \n", "\n", "Epoch number 545\n", "Cost: 25.521852422822665 e4\n", "Patience: 196 / 200\n", "Last checkpoint at: 
Epoch 540 \n", "\n", "Epoch number 550\n", "Cost: 24.915404786217085 e4\n", "Patience: 198 / 200\n", "Last checkpoint at: Epoch 540 \n", "\n", "Epoch number 555\n", "Cost: 25.868487217404105 e4\n", "Patience: 198 / 200\n", "Last checkpoint at: Epoch 540 \n", "\n", "Epoch number 560\n", "Cost: 27.24954412576366 e4\n", "Patience: 199 / 200\n", "Last checkpoint at: Epoch 540 \n", "\n", "\n", " Early stopping at epoch 565 , difference: 2.3366942843223992e-05\n", "Cost: 0.002444114783739156\n" ] } ], "source": [ "rnn.fit(minibatches, epochs = 5000, print_step=5)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "scrolled": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X2QXXd93/H35z7v6vlhMbZkWTKI1nIgdljkJARIgnFEkrFpArEoaZ3WrYcMnpDSJthDxrROM03IDAEGN8FpNG0oRDyFRkOVGMcGppQYtMbGRgbXsmLsRTKWrWdpH+7Dt3+cs7tHV/fuXT0crXb385q5c8/53XPu/s7qaj/39/udc36KCMzMzKZTmO0KmJnZxc9hYWZmPTkszMysJ4eFmZn15LAwM7OeHBZmZtaTw8LMzHpyWJiZWU8OCzMz66k02xU4X1avXh3r16+f7WqYmc0pDz/88IsRMdBru3kTFuvXr2doaGi2q2FmNqdI+sFMtnM3lJmZ9eSwMDOznhwWZmbWk8PCzMx6cliYmVlPDgszM+vJYWFmZj0t+LA4Mdbgw19+kkeePTTbVTEzu2gt+LAYa7T42IN7eGz4yGxXxczsorXgw6JYEAD1ZmuWa2JmdvFa8GFRLiZh0WzFLNfEzOziteDDYqJl0XBYmJl1teDDolRIfgWNpsPCzKybBR8WxYKQoNnymIWZWTcLPiwASgVRdzeUmVlXDguSrigPcJuZdeewIG1Z+NRZM7OuHBZAqSi3LMzMpuGwAIqFAnWfDWVm1lWuYSFpi6QnJe2RdEeH198t6XFJj0r6uqRNafl6SSNp+aOS/izPepYK8tlQZmbTKOX1xpKKwD3AW4BhYJekHRHxRGazT0fEn6Xb3wh8GNiSvvZ0RFyTV/2ySkX5ojwzs2nk2bLYDOyJiL0RMQ5sB27KbhARRzOri4BZ+YtdKsgX5ZmZTSPPsFgDPJdZH07LTiHpPZKeBj4E/FbmpQ2SHpH0NUlvyLGelIo+ddbMbDp5hoU6lJ32Fzki7omIVwDvB34vLd4PrIuIa4H3AZ+WtPS0HyDdJmlI0tCBAwfOuqI+ddbMbHp5hsUwcHlmfS2wb5rttwNvA4iIsYh4KV1+GHgaeFX7DhFxb0QMRsTgwMDAWVfUp86amU0vz7DYBWyUtEFSBdgK7MhuIGljZvWXgKfS8oF0gBxJVwIbgb15VbRYKPh2H2Zm08jtbKiIaEi6HbgPKALbImK3pLuBoYjYAdwu6XqgDhwCbkl3fyNwt6QG0ATeHREH86qrT501M5
tebmEBEBE7gZ1tZXdllt/bZb8vAF/Is25ZPhvKzGx6voIbX2dhZtaLw4LkrrMOCzOz7hwWTHRDeczCzKwbhwU+ddbMrBeHBUk3lC/KMzPrzmFBMg+3WxZmZt05LPDZUGZmvTgs8HUWZma9OCxI7jrrloWZWXcOC9KWhW/3YWbWlcOC5GyopruhzMy6cliQDHDX3bIwM+vKYYFPnTUz68VhAZQLPnXWzGw6DguSyY8icOvCzKwLhwXJmAXgM6LMzLpwWJCcOgv4wjwzsy5yDQtJWyQ9KWmPpDs6vP5uSY9LelTS1yVtyrx2Z7rfk5J+Ic96lorJr8HjFmZmneUWFpKKwD3AW4FNwDuzYZD6dES8OiKuAT4EfDjddxOwFbga2AL81/T9cjHVsnA3lJlZJ3m2LDYDeyJib0SMA9uBm7IbRMTRzOoiYOKr/U3A9ogYi4h/BPak75eLYhoWHuA2M+uslON7rwGey6wPA9e1byTpPcD7gArw85l9H2rbd00+1YTy5AC3w8LMrJM8WxbqUHbaX+OIuCciXgG8H/i9M9lX0m2ShiQNHThw4KwrWiykYxYe4DYz6yjPsBgGLs+srwX2TbP9duBtZ7JvRNwbEYMRMTgwMHDWFS371Fkzs2nlGRa7gI2SNkiqkAxY78huIGljZvWXgKfS5R3AVklVSRuAjcC38qroxJiFu6HMzDrLbcwiIhqSbgfuA4rAtojYLeluYCgidgC3S7oeqAOHgFvSfXdL+izwBNAA3hMRzbzqWnI3lJnZtPIc4CYidgI728ruyiy/d5p9/wD4g/xqN2Xy1Fl3Q5mZdeQruIGiz4YyM5uWwwIop91Qvs7CzKwzhwVTA9x1X8FtZtaRw4KpU2fdsjAz68xhQebUWZ8NZWbWkcMCKPuus2Zm03JYkG1ZeMzCzKwThwXZ6yzcsjAz68RhwdTkRx7gNjPrzGHBVMvCp86amXXmsABKPnXWzGxaDgsyF+U5LMzMOnJYkLndh7uhzMw6cljgGwmamfXisMCnzpqZ9eKwYGryIw9wm5l15rBgqmUx3vCYhZlZJw4LoFAQlWKBcQ9wm5l1lGtYSNoi6UlJeyTd0eH190l6QtJjkh6QdEXmtaakR9PHjjzrCVAtFRit5zbNt5nZnJbbHNySisA9wFuAYWCXpB0R8URms0eAwYg4Kek3gQ8BN6evjUTENXnVr121XGDM3VBmZh3l2bLYDOyJiL0RMQ5sB27KbhARX4mIk+nqQ8DaHOszrWqpyFjdYWFm1kmeYbEGeC6zPpyWdXMr8LeZ9ZqkIUkPSXpbpx0k3ZZuM3TgwIFzqmy1XGC04W4oM7NOcuuGAtShrOO5qZJ+HRgE3pQpXhcR+yRdCTwo6fGIePqUN4u4F7gXYHBw8JzOe3XLwsysuzxbFsPA5Zn1tcC+9o0kXQ98ALgxIsYmyiNiX/q8F/gqcG2OdaVWLjDmloWZWUd5hsUuYKOkDZIqwFbglLOaJF0LfIIkKF7IlK+QVE2XVwOvB7ID4+ddteQBbjOzbnLrhoqIhqTbgfuAIrAtInZLuhsYiogdwB8Di4HPSQJ4NiJuBK4CPiGpRRJof9h2FtV5Vy0VOXxyPM8fYWY2Z+U5ZkFE7AR2tpXdlVm+vst+3wBenWfd2tV86qyZWVe+gjtVLRUdFmZmXTgsUr6C28ysO4dFqlZ2y8LMrBuHRapaKjDmloWZWUcOi1RyBbdbFmZmnTgsUrVSkWYraPg25WZmp3FYpKrl5FfhcQszs9M5LFLVUhHAZ0SZmXXgsEjV3LIwM+vKYZGaaFk4LMzMTuewSFVLya/C3VBmZqdzWKRqZbcszMy6cVikJloWvjDPzOx0DovUxKmzvjDPzOx0DovU5AC3WxZmZqdxWKR86qyZWXcOi5QvyjMz625GYSHpkzMp67DNFklPStoj6Y4Or79P0hOSHpP0gKQrMq/dIump9HHLTOp5Lny7DzOz7mbasrg6uyKpCLx2uh3Sbe4B3gpsAt
4paVPbZo8AgxHxGuDzwIfSfVcCHwSuAzYDH5S0YoZ1PSu+KM/MrLtpw0LSnZKOAa+RdDR9HANeAP6mx3tvBvZExN6IGAe2AzdlN4iIr0TEyXT1IWBtuvwLwP0RcTAiDgH3A1vO6MjO0OSpsw13Q5mZtZs2LCLiv0TEEuCPI2Jp+lgSEasi4s4e770GeC6zPpyWdXMr8Ldnue85m7qC2y0LM7N2M+2G+pKkRQCSfl3Sh7PjC12oQ1l03FD6dWAQ+OMz2VfSbZKGJA0dOHCgR3WmJymZLc8tCzOz08w0LP4UOCnpx4HfBX4A/GWPfYaByzPra4F97RtJuh74AHBjRIydyb4RcW9EDEbE4MDAwAwPpbtkalW3LMzM2s00LBoRESRjDh+NiI8CS3rsswvYKGmDpAqwFdiR3UDStcAnSILihcxL9wE3SFqRDmzfkJblqlYuumVhZtZBaYbbHZN0J/AvgDekZzqVp9shIhqSbif5I18EtkXEbkl3A0MRsYOk22kx8DlJAM9GxI0RcVDS75MEDsDdEXHwjI/uDFXLblmYmXUy07C4GfjnwL+OiOclrWNqfKGriNgJ7GwruyuzfP00+24Dts2wfudFtVT0qbNmZh3MqBsqIp4HPgUsk/TLwGhE9BqzmHNq5YKv4DYz62CmV3D/GvAt4B3ArwHflPT2PCs2G9yyMDPrbKbdUB8AXjcxCC1pAPh7kquu5w2fOmtm1tlMz4YqtJ2t9NIZ7Dtn1MpFX5RnZtbBTFsWfyfpPuCv0vWbaRu4ng/csjAz62zasJD0SuCSiPgdSb8C/AzJ1dX/QDLgPa8kYeGWhZlZu15dSR8BjgFExF9HxPsi4t+RtCo+knflLrSkG8otCzOzdr3CYn1EPNZeGBFDwPpcajSL3LIwM+usV1jUpnmt73xW5GJQLRd9BbeZWQe9wmKXpH/bXijpVuDhfKo0e2qlAqONJsltsMzMbEKvs6F+G/iipHcxFQ6DQAX4Z3lWbDZUy0UioN4MKqVOd0k3M1uYpg2LiPgR8NOSfg74sbT4f0fEg7nXbBZkZ8urlObdZSRmZmdtRtdZRMRXgK/kXJdZVy0n83CP1lssmW60xsxsgfHX54ypqVV9+qyZWZbDIqMvbVn4Km4zs1M5LDJqmW4oMzOb4rDImGhZjLgbyszsFLmGhaQtkp6UtEfSHR1ef6Okb0tqtM+PIakp6dH0saN93zzUyh6zMDPrZKZ3nT1j6Tzd9wBvAYZJLvDbERFPZDZ7FvgN4D90eIuRiLgmr/p1MtENNTLusDAzy8otLIDNwJ6I2AsgaTtwEzAZFhHxTPraRTFIMDlm4ftDmZmdIs9uqDXAc5n14bRspmqShiQ9JOlt57dqnfVV0rBwy8LM7BR5tiw63S/jTG66tC4i9km6EnhQ0uMR8fQpP0C6DbgNYN26dWdf01Rt4joLnzprZnaKPFsWw8DlmfW1wL6Z7hwR+9LnvcBXgWs7bHNvRAxGxODAwMC51RaPWZiZdZNnWOwCNkraIKkCbAVmdFaTpBWSqunyauD1ZMY68uLrLMzMOsstLCKiAdwO3Ad8D/hsROyWdLekGwEkvU7SMPAO4BOSdqe7XwUMSfoOyT2p/rDtLKpcFAuiUiy4G8rMrE2eYxZExE6SKVizZXdllneRdE+17/cN4NV51q2bWrngbigzsza+grtNrVz0vaHMzNo4LNr0VYpuWZiZtXFYtKmVih7gNjNr47BoU6sUfSNBM7M2Dos2tVLBNxI0M2vjsGjTVyk6LMzM2jgs2njMwszsdA6LNrVywWMWZmZtHBZt3A1lZnY6h0WbaslhYWbWzmHRJmlZeMzCzCzLYdGmVioy3mzRbJ3J1BtmZvObw6JNXyWdAMldUWZmkxwWbabmtHBYmJlNcFi0mZwtz2FhZjbJYdHGs+WZmZ3OYdGmz/Nwm5mdxmHRZnE1mTzw+FhjlmtiZnbxyDUsJG2R9KSkPZLu6PD6GyV9W1JD0tvbXrtF0l
Pp45Y865nlsDAzO11uYSGpCNwDvBXYBLxT0qa2zZ4FfgP4dNu+K4EPAtcBm4EPSlqRV12zFteSsDjhsDAzm5Rny2IzsCci9kbEOLAduCm7QUQ8ExGPAe2jyb8A3B8RByPiEHA/sCXHuk5aVE3GLI45LMzMJuUZFmuA5zLrw2nZedtX0m2ShiQNHThw4KwrmrWkWgbg+KjDwsxsQp5hoQ5lM72Hxoz2jYh7I2IwIgYHBgbOqHLd1MoFigW5G8rMLCPPsBgGLs+srwX2XYB9z4kkFlWKHuA2M8vIMyx2ARslbZBUAbYCO2a4733ADZJWpAPbN6RlF8SSWplj7oYyM5uUW1hERAO4neSP/PeAz0bEbkl3S7oRQNLrJA0D7wA+IWl3uu9B4PdJAmcXcHdadkEsqhbdDWVmllHK880jYiews63srszyLpIupk77bgO25Vm/bpbUyhwbq8/GjzYzuyj5Cu4OlveVOXTCYWFmNsFh0cHy/gqHT47PdjXMzC4aDosOVi4qc+ikWxZmZhMcFh0s768wUm96AiQzs5TDooMV/RUADrkryswMcFh0tKI/ueWHB7nNzBIOiw5WLEpaFh7kNjNLOCw6mOiGOuiwMDMDHBYdTXZD+YwoMzPAYdHR8rRlcfiEWxZmZuCw6KhSKrC4WnLLwsws5bDoYnl/2afOmpmlHBZdrOivOCzMzFIOiy5WLKq4G8rMLOWw6GJFf9nXWZiZpRwWXazor3DQZ0OZmQEOi66W9ydTqzaardmuipnZrMs1LCRtkfSkpD2S7ujwelXSZ9LXvylpfVq+XtKIpEfTx5/lWc9OVk7c8mPE4xZmZrlNqyqpCNwDvAUYBnZJ2hERT2Q2uxU4FBGvlLQV+CPg5vS1pyPimrzq18vkhXknx1m9uDpb1TAzuyjk2bLYDOyJiL0RMQ5sB25q2+Ym4H+ky58H3ixJOdZpxiZu+XHQd541M8s1LNYAz2XWh9OyjttERAM4AqxKX9sg6RFJX5P0hhzr2ZHntDAzm5JbNxTQqYUQM9xmP7AuIl6S9Frgf0m6OiKOnrKzdBtwG8C6devOQ5WnvGxJ0vX0o6Oj5/V9zczmojxbFsPA5Zn1tcC+bttIKgHLgIMRMRYRLwFExMPA08Cr2n9ARNwbEYMRMTgwMHBeK796cZVKscAPD42c1/c1M5uL8gyLXcBGSRskVYCtwI62bXYAt6TLbwcejIiQNJAOkCPpSmAjsDfHup6mUBCXLa8xfNhhYWaWWzdURDQk3Q7cBxSBbRGxW9LdwFBE7AD+AvikpD3AQZJAAXgjcLekBtAE3h0RB/OqazdrVvS5ZWFmRr5jFkTETmBnW9ldmeVR4B0d9vsC8IU86zYT61b287fffZ6I4CI5ScvMbFb4Cu5p/NOXL+XwyTrPe5DbzBY4h8U0rr5sKQC7f3i0x5ZmZvObw2Iamy5bSqkgHn720GxXxcxsVjksptFfKXHN5cv5xp4XZ7sqZmazymHRw5teNcBjPzzCcwdPznZVzMxmjcOih1957VoAPjv0XI8tzczmL4dFD2uW9/GmVw3w2aHnGGs0Z7s6ZmazwmExA7f+zAZ+dHSMbV9/ZrarYmY2KxwWM/CGjQNcf9UlfPzBp9h/xFd0m9nC47CYobt+eRPNCN79P7/tubnNbMFxWMzQulX9fGzrtXx//1He/qff4NmXfHaUmS0cDoszcMPVL+dT/+Y6Xjoxzps//FXe//nHOOzJkcxsAVBE+3xEc9Pg4GAMDQ1dkJ/13MGT/Pn/2csnH/oBlWKBG65+Oddevpx3DK5lSa18QepgZnY+SHo4IgZ7buewOHu79x3hM7ue428e3ceRkTpLayWuu3IVb3zVAL/6E2vor+R6U18zs3PmsLiAWq3g8R8eYdv//Ue+9Nh+mq3kd3rl6kW8bv1KNm9YyY+tWcYVq/qplYuzUkczs04cFrMkInjw+y9w/xM/4uEfHOLpA8dJswMJLlvWx7qV/axaXGHVogovW1
rjsuU1Ll3Wx2XL+hhYUqVWLnj+DDO7IGYaFu4nOc8k8earLuHNV10CJK2O3fuOsvfF4zzz4kmeeekEzx48ya5nDnJyvMmx0cZp71EpFfgnlyzhZUuqLOsvs7yvwvL+MtVSgXKxwIbVi1i1uEK1VORlS6qUimJRpUSh4IAxs3w4LHJWKIhXr13Gq9cu6/j6yfEG+4+Msv/wKPuOjHDg2BgHT4zz/eeP8vzRUb7//DGOjNQ5PnZ6qGRVigUuXV5jUaVEtVygWiqwqFKir1KkWipOllVLRSqlieXkUUlDqFxMlivpc6kgAigWkjAqFUVBUJAoFQqUiqJUEMX0USgk6wWlZZIDzGyeyDUsJG0BPkoyB/d/i4g/bHu9Cvwl8FrgJeDmiHgmfe1O4FaSObh/KyLuy7Ous6W/UuIVA4t5xcDiabdrNFuMN1uM1ls89aNjHBttcPDEOEdG6pwYbzBSb7L/8CgnxxuMNVqM1pvsPzLKSL3JeKPFWKPJWKPFWKPFeKN1gY4uMRUcTAZINkyKabgoDaLscmHyeWp/peXJdlPbZNez27W/Z/v7d96v888oiFNenwjDbF3V9jyxz2nrE/sVQGRfn1ie+lmCZLvM+rHRBrVygUYruGRpjWYriAgmOpazPcwFQb0ZnBhrUG+2uGLVIvYfGWGk3uTK1YspFODwyTrjjRbL+sq8eHyM5f1lFlfLLO0r8cLRMV46Mcaa5f28dHyMsWaLK1b2A1AuFmhFcHSkwcrFlcnP29JamdF6k6OjDS5bVuOlE+NIMLC4ytHRBuWiaLaCxdUSJ8aak19oKqUCx0YblIsFRupNVi+usO/wKK1IjvPoSJ2+cjGp54kxFlVKjDda1JstVi6qcHikzor+CkdH6iyplSgWxNMHTvDyZTWK0indvBHB/iOjvHxpbfLfcqzRpFwonPJFZ7TenPxi1K2LuNkKBIzUm/RXikRwyns0mi2Ufv5mMlXzxBDBxdIlnVtYSCoC9wBvAYaBXZJ2RMQTmc1uBQ5FxCslbQX+CLhZ0iZgK3A1cBnw95JeFREL9k5+pWKBUrFAfwWuu3LVOb1XqxWMN1tpeDSpN4N6+p9tvNmi3gzG6k0a6Ye/GcHx0QbNCFqR7N9oBc1Wsm2zlTxaMVEetFqRbJ8+N1vQilO3PfU5eb2VXY6g1ZpYzpSldciuN5qt07drte3T9v4RTNbh1P06/7xgaj+bfQUxOR7YTaWYXEo23pz6gtRfKdJMP8NFifFmi4KS0IuAeqtFXzlpgUf6OTg+1iAClvWVJ4M3IigURKM58Rmfer9KsUC5KJb2lSc/QxP/h4oSjVaLpbUyAaf8f2n/P1ErF1i1qMpYo8lovUVE0F8tEQEj4w1WL6lSkNh06VLueddP5PjbzrdlsRnYExF7ASRtB24CsmFxE/Af0+XPAx9XEqM3AdsjYgz4R0l70vf7hxzru2AUCqJWKKZnZvm6kDMVaWBkQwam1putgIlAYip0Isjs12E9fe9smHV6bkWyXalYoN5sEQHHx+pTLbDMN9GJxUYrKKffihutFsdGG1y6LGmNHDg2RpD8saw3ky8NtXIx+fKQtkYXVUusWlThpRPjDCyuMtpocujEOAWJevqNeVGlyMGT45STphLHRhv0lYv0V4qcGE/+2AZwYqzBqkUVjo02WFwtcXhknGqpSCuCkXoTIZbUSun+BQ6erFOUqJQKjNab9KV/7I+PNVhSKzFab1FOu0xbEfSVixwZqbNyUYUDx8dotZJv8bVykVYrODHemOxmHW+0qJQKk/9uBYlyURwbbdCKSFt0oq9cpBnBsdEGpYIopyHUaLYopcsRUC6JSrHAyfFm0pqvtyZbsrVykVJBjKZlY40WxWxrW1PduRNlR06OJy3ISpFaKTmTcqSe/C7LxQLHRuu0Aq5Y1Z/75z7PsFgDZCeBGAau67ZNRDQkHQFWpeUPte27pv0HSLoNuA1g3bp1563iZtOZ7FLi4ugeMLsQ8rzdR6f/Se
2Nxm7bzGRfIuLeiBiMiMGBgYGzqKKZmc1EnmExDFyeWV8L7Ou2jaQSsAw4OMN9zczsAskzLHYBGyVtkFQhGbDe0bbNDuCWdPntwIORnAKwA9gqqSppA7AR+FaOdTUzs2nkNmaRjkHcDtxHcurstojYLeluYCgidgB/AXwyHcA+SBIopNt9lmQwvAG8ZyGfCWVmNtt8uw8zswVsprf78HwWZmbWk8PCzMx6cliYmVlP82bMQtIB4Afn8BargRfPU3UuJvPxuObjMYGPa66ZL8d1RUT0vFBt3oTFuZI0NJNBnrlmPh7XfDwm8HHNNfP1uLpxN5SZmfXksDAzs54cFlPune0K5GQ+Htd8PCbwcc018/W4OvKYhZmZ9eSWhZmZ9bTgw0LSFklPStoj6Y7Zrs+ZkLRN0guSvpspWynpfklPpc8r0nJJ+lh6nI9JyndarXMg6XJJX5H0PUm7Jb03LZ/TxyapJulbkr6THtd/Sss3SPpmelyfSW+8SXojzc+kx/VNSetns/7TkVSU9IikL6Xr8+GYnpH0uKRHJQ2lZXP6M3guFnRYZKZ+fSuwCXhnOqXrXPHfgS1tZXcAD0TERuCBdB2SY9yYPm4D/vQC1fFsNIB/HxFXAT8JvCf9d5nrxzYG/HxE/DhwDbBF0k+STCf8J+lxHSKZbhgy0w4Df5Jud7F6L/C9zPp8OCaAn4uIazKnyM71z+DZS6aIXJgP4KeA+zLrdwJ3zna9zvAY1gPfzaw/CVyaLl8KPJkufwJ4Z6ftLvYH8Dckc7nPm2MD+oFvk8we+SJQSssnP5Mkd2z+qXS5lG6n2a57h2NZS/KH8+eBL5FMXjanjymt3zPA6rayefMZPNPHgm5Z0Hnq19Omb51jLomI/QDp88vS8jl5rGk3xbXAN5kHx5Z21zwKvADcDzwNHI6IRrpJtu6nTDsMTEw7fLH5CPC7QCtdX8XcPyZIZuf8sqSH0ymcYR58Bs9WnnNwzwUzmr51nphzxyppMfAF4Lcj4qjUdc7rOXNskczLco2k5cAXgas6bZY+X/THJemXgRci4mFJPztR3GHTOXNMGa+PiH2SXgbcL+n702w7l47rrCz0lsV8nL71R5IuBUifX0jL59SxSiqTBMWnIuKv0+J5cWwAEXEY+CrJmMzydFphOLXu3aYdvpi8HrhR0jPAdpKuqI8wt48JgIjYlz6/QBLsm5lHn8EztdDDYiZTv8412alqbyHp758o/5fpWRs/CRyZaE5fbJQ0If4C+F5EfDjz0pw+NkkDaYsCSX3A9SSDwl8hmVYYTj+uTtMOXzQi4s6IWBsR60n+/zwYEe9iDh8TgKRFkpZMLAM3AN9ljn8Gz8lsD5rM9gP4ReD/kfQdf2C263OGdf8rYD9QJ/lmcytJ/+8DwFPp88p0W5Gc+fU08DgwONv1n+a4foakCf8Y8Gj6+MW5fmzAa4BH0uP6LnBXWn4lyRzze4DPAdW0vJau70lfv3K2j6HH8f0s8KX5cExp/b+TPnZP/G2Y65/Bc3n4Cm4zM+tpoXdDmZnZDDgszMysJ4eFmZn15LAwM7OeHBZmZtaTw8LsDEhqpnchnXictzsVS1qvzB2EzS4mC/12H2ZnaiQirpntSphdaG5ZmJ0H6dwHf5TOV/EtSa9My6+Q9EA6x8EDktal5ZdI+mI6t8V3JP10+lZFSX+eznfx5fRKb7NZ57AwOzN9bd1QN2deOxoRm4GPk9wfiXT5LyPiNcCngI+l5R8DvhbJ3BY/QXKVMCTzIdwTEVcDh4Ffzfl4zGbEV3CbnQFJxyNicYfyZ0gmNtqb3gRoJxbiAAAA4ElEQVTx+YhYJelFknkN6mn5/ohYLekAsDYixjLvsR64P5KJdZD0fqAcEf85/yMzm55bFmbnT3RZ7rZNJ2OZ5SYeV7SLhMPC7Py5OfP8D+nyN0juxgrwLuDr6fIDwG/C5IRISy9UJc3Ohr+1mJ2ZvnSmuwl/FxETp89WJX2T5EvYO9Oy3wK2Sf
od4ADwr9Ly9wL3SrqVpAXxmyR3EDa7KHnMwuw8SMcsBiPixdmui1ke3A1lZmY9uWVhZmY9uWVhZmY9OSzMzKwnh4WZmfXksDAzs54cFmZm1pPDwszMevr/9+FuGHi0DGAAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(rnn.loss_list)\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Cost\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "#save in a folder that describes the model\n", "folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", "rnn.save(folder)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm_leaky_relu_5l_[50,40,30,20,10]c/rnn_basic\n" ] } ], "source": [ "#folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", "#rnn.load(folder)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "###test_input.shape###" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "#Here I predict based on my test set\n", "\n", "test_pred = rnn.predict(test_input)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "operands could not be broadcast together with shapes (469,21) (24,) (469,21) ", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m<ipython-input-41-1a19da3ab328>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m#Here i subtract a prediction (random 
particle) from the target to get an idea of the predictions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmin_max_scaler\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_input\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;31m#print(min_max_scaler_inv(test_pred)-min_max_scaler_inv(test_target))\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m<ipython-input-8-ab3cbe9c0a7e>\u001b[0m in \u001b[0;36mmin_max_scaler\u001b[1;34m(arr, min_max_scalor)\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 15\u001b[1;33m \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mreshapor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmin_max_scalor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mreshapor_inv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 16\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmin_max_scalor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mc:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\sklearn\\preprocessing\\data.py\u001b[0m in \u001b[0;36mtransform\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 
367\u001b[0m \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mFLOAT_DTYPES\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 368\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 369\u001b[1;33m \u001b[0mX\u001b[0m \u001b[1;33m*=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mscale_\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 370\u001b[0m \u001b[0mX\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmin_\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 371\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mValueError\u001b[0m: operands could not be broadcast together with shapes (469,21) (24,) (469,21) " ] } ], "source": [ "#Here i subtract a prediction (random particle) from the target to get an idea of the predictions\n", "min_max_scaler(test_input)\n", "\n", "\n", "#print(min_max_scaler_inv(test_pred)-min_max_scaler_inv(test_target))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Here I evaluate my model on the test set based on mean_squared_error\n", "\n", "rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", 
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }